aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@yahooinc.com>2023-10-18 14:01:51 +0200
committerHåvard Pettersen <havardpe@yahooinc.com>2023-10-18 14:17:23 +0200
commita37a53cec5aa818ea005e4cefb3349eb9fe0b769 (patch)
treea244ecae15d11798829cdac72620f99ca31a709f
parentc12bbb3c36628eeaa598d3567c57c1d4a81b8ffd (diff)
support default values for query featureshavardpe/extract-default-query-feature-values
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java13
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java16
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java13
-rw-r--r--container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java50
4 files changed, 50 insertions, 42 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
index 829d0c268e5..91acc883803 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
@@ -13,10 +13,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.data.access.helpers.MatchFeatureData;
import com.yahoo.data.access.helpers.MatchFeatureFilter;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-import java.util.Optional;
+import java.util.*;
import java.util.function.Supplier;
import java.util.logging.Logger;
@@ -52,12 +49,12 @@ public class GlobalPhaseRanker {
static void rerankHitsImpl(GlobalPhaseSetup setup, Query query, Result result) {
var mainSpec = setup.globalPhaseEvalSpec;
- var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), query);
+ var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), setup.defaultValues, query);
int rerankCount = resolveRerankCount(setup, query);
var normalizers = new ArrayList<NormalizerContext>();
for (var nSetup : setup.normalizers) {
var normSpec = nSetup.inputEvalSpec();
- var normEvalSrc = withQueryPrep(normSpec.evalSource(), normSpec.fromQuery(), query);
+ var normEvalSrc = withQueryPrep(normSpec.evalSource(), normSpec.fromQuery(), setup.defaultValues, query);
normalizers.add(new NormalizerContext(nSetup.name(), nSetup.supplier().get(), normEvalSrc, normSpec.fromMF()));
}
var rescorer = new HitRescorer(mainSrc, mainSpec.fromMF(), normalizers);
@@ -73,8 +70,8 @@ public class GlobalPhaseRanker {
}
}
- static Supplier<Evaluator> withQueryPrep(Supplier<Evaluator> evalSource, List<String> queryFeatures, Query query) {
- var prepared = PreparedInput.findFromQuery(query, queryFeatures);
+ static Supplier<Evaluator> withQueryPrep(Supplier<Evaluator> evalSource, List<String> queryFeatures, Map<String, Tensor> defaultValues, Query query) {
+ var prepared = PreparedInput.findFromQuery(query, queryFeatures, defaultValues);
Supplier<Evaluator> supplier = () -> {
var evaluator = evalSource.get();
for (var entry : prepared) {
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java
index 31a676e4c8e..e5cd09d3a18 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java
@@ -3,15 +3,10 @@ package com.yahoo.search.ranking;
import ai.vespa.models.evaluation.FunctionEvaluator;
+import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.config.search.RankProfilesConfig;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-import java.util.Map;
-import java.util.HashMap;
+import java.util.*;
import java.util.function.Supplier;
class GlobalPhaseSetup {
@@ -20,16 +15,19 @@ class GlobalPhaseSetup {
final int rerankCount;
final Collection<String> matchFeaturesToHide;
final List<NormalizerSetup> normalizers;
+ final Map<String, Tensor> defaultValues;
GlobalPhaseSetup(FunEvalSpec globalPhaseEvalSpec,
final int rerankCount,
Collection<String> matchFeaturesToHide,
- List<NormalizerSetup> normalizers)
+ List<NormalizerSetup> normalizers,
+ Map<String, Tensor> defaultValues)
{
this.globalPhaseEvalSpec = globalPhaseEvalSpec;
this.rerankCount = rerankCount;
this.matchFeaturesToHide = matchFeaturesToHide;
this.normalizers = normalizers;
+ this.defaultValues = defaultValues;
}
static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) {
@@ -106,7 +104,7 @@ class GlobalPhaseSetup {
}
Supplier<Evaluator> supplier = SimpleEvaluator.wrap(functionEvaluatorSource);
var gfun = new FunEvalSpec(supplier, fromQuery, fromMF);
- return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers);
+ return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers, Collections.emptyMap());
}
return null;
}
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java
index 5491724cc08..914635fef59 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java
@@ -13,16 +13,13 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.data.access.helpers.MatchFeatureData;
import com.yahoo.data.access.helpers.MatchFeatureFilter;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-import java.util.Optional;
+import java.util.*;
import java.util.function.Supplier;
import java.util.logging.Logger;
record PreparedInput(String name, Tensor value) {
- static List<PreparedInput> findFromQuery(Query query, Collection<String> queryFeatures) {
+ static List<PreparedInput> findFromQuery(Query query, Collection<String> queryFeatures, Map<String, Tensor> defaultValues) {
List<PreparedInput> result = new ArrayList<>();
var ranking = query.getRanking();
var rankFeatures = ranking.getFeatures();
@@ -36,6 +33,12 @@ record PreparedInput(String name, Tensor value) {
feature = rankFeatures.getTensor(needed);
}
if (feature.isEmpty()) {
+ var t = defaultValues.get(needed);
+ if (t != null) {
+ feature = Optional.of(t);
+ }
+ }
+ if (feature.isEmpty()) {
throw new IllegalArgumentException("missing query feature: " + queryFeatureName);
}
result.add(new PreparedInput(needed, feature.get()));
diff --git a/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java
index ce9ac377908..f55130c0c93 100644
--- a/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java
+++ b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java
@@ -60,17 +60,20 @@ public class GlobalPhaseRerankHitsImplTest {
static NormalizerSetup makeNormalizer(String name, List<Double> expected, FunEvalSpec evalSpec) {
return new NormalizerSetup(name, () -> new ExpectingNormalizer(expected), evalSpec);
}
- static GlobalPhaseSetup makeFullSetup(FunEvalSpec mainSpec, int rerankCount,
- List<String> hiddenMF, List<NormalizerSetup> normalizers)
- {
- return new GlobalPhaseSetup(mainSpec, rerankCount, hiddenMF, normalizers);
- }
- static GlobalPhaseSetup makeSimpleSetup(FunEvalSpec mainSpec, int rerankCount) {
- return makeFullSetup(mainSpec, rerankCount, Collections.emptyList(), Collections.emptyList());
- }
- static GlobalPhaseSetup makeNormSetup(FunEvalSpec mainSpec, List<NormalizerSetup> normalizers) {
- return makeFullSetup(mainSpec, 100, Collections.emptyList(), normalizers);
- }
+ static class SetupBuilder {
+ FunEvalSpec mainSpec = makeConstSpec(0.0);
+ int rerankCount = 100;
+ List<String> hiddenMF = new ArrayList<>();
+ List<NormalizerSetup> normalizers = new ArrayList<>();
+ Map<String, Tensor> defaultValues = new HashMap<>();
+ SetupBuilder eval(FunEvalSpec spec) { mainSpec = spec; return this; }
+ SetupBuilder rerank(int value) { rerankCount = value; return this; }
+ SetupBuilder hide(String mf) { hiddenMF.add(mf); return this; }
+ SetupBuilder addNormalizer(NormalizerSetup normalizer) { normalizers.add(normalizer); return this; }
+ SetupBuilder addDefault(String name, Tensor value) { defaultValues.put(name, value); return this; }
+ GlobalPhaseSetup build() { return new GlobalPhaseSetup(mainSpec, rerankCount, hiddenMF, normalizers, defaultValues); }
+ }
+ static SetupBuilder setup() { return new SetupBuilder(); }
static record NamedValue(String name, double value) {}
NamedValue value(String name, double value) {
return new NamedValue(name, value);
@@ -167,7 +170,7 @@ public class GlobalPhaseRerankHitsImplTest {
}
}
@Test void partialRerankWithRescaling() {
- var setup = makeSimpleSetup(makeConstSpec(3.0), 2);
+ var setup = setup().rerank(2).eval(makeConstSpec(3.0)).build();
var query = makeQuery(Collections.emptyList());
var result = makeResult(query, List.of(hit("a", 3), hit("b", 4), hit("c", 5), hit("d", 6)));
var expect = Expect.make(List.of(hit("a", 1), hit("b", 2), hit("c", 3), hit("d", 3)));
@@ -175,8 +178,7 @@ public class GlobalPhaseRerankHitsImplTest {
expect.verifyScores(result);
}
@Test void matchFeaturesCanBePartiallyHidden() {
- var setup = makeFullSetup(makeSumSpec(Collections.emptyList(), List.of("public_value", "private_value")), 2,
- List.of("private_value"), Collections.emptyList());
+ var setup = setup().eval(makeSumSpec(Collections.emptyList(), List.of("public_value", "private_value"))).hide("private_value").build();
var query = makeQuery(Collections.emptyList());
var factory = new HitFactory(List.of("public_value", "private_value"));
var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("public_value", 2), value("private_value", 3))),
@@ -188,8 +190,7 @@ public class GlobalPhaseRerankHitsImplTest {
verifyDoesNotHaveMF(result, "private_value");
}
@Test void matchFeaturesCanBeRemoved() {
- var setup = makeFullSetup(makeSumSpec(Collections.emptyList(), List.of("private_value")), 2,
- List.of("private_value"), Collections.emptyList());
+ var setup = setup().eval(makeSumSpec(Collections.emptyList(), List.of("private_value"))).hide("private_value").build();
var query = makeQuery(Collections.emptyList());
var factory = new HitFactory(List.of("private_value"));
var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("private_value", 3))),
@@ -200,7 +201,7 @@ public class GlobalPhaseRerankHitsImplTest {
verifyDoesNotHaveMatchFeaturesField(result);
}
@Test void queryFeaturesCanBeUsed() {
- var setup = makeSimpleSetup(makeSumSpec(List.of("foo"), List.of("bar")), 2);
+ var setup = setup().eval(makeSumSpec(List.of("foo"), List.of("bar"))).build();
var query = makeQuery(List.of(value("query(foo)", 7)));
var factory = new HitFactory(List.of("bar"));
var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("bar", 2))),
@@ -211,7 +212,7 @@ public class GlobalPhaseRerankHitsImplTest {
verifyHasMF(result, "bar");
}
@Test void queryFeaturesCanBeUsedWhenPrepared() {
- var setup = makeSimpleSetup(makeSumSpec(List.of("foo"), List.of("bar")), 2);
+ var setup = setup().eval(makeSumSpec(List.of("foo"), List.of("bar"))).build();
var query = makeQueryWithPrepare(List.of(value("query(foo)", 7)));
var factory = new HitFactory(List.of("bar"));
var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("bar", 2))),
@@ -221,9 +222,18 @@ public class GlobalPhaseRerankHitsImplTest {
expect.verifyScores(result);
verifyHasMF(result, "bar");
}
+ @Test void queryFeaturesCanBeDefaultValues() {
+ var setup = setup().eval(makeSumSpec(List.of("foo", "bar"), Collections.emptyList()))
+ .addDefault("query(bar)", Tensor.from(5.0)).build();
+ var query = makeQuery(List.of(value("query(foo)", 7)));
+ var result = makeResult(query, List.of(hit("a", 1)));
+ var expect = Expect.make(List.of(hit("a", 12)));
+ GlobalPhaseRanker.rerankHitsImpl(setup, query, result);
+ expect.verifyScores(result);
+ }
@Test void withNormalizer() {
- var setup = makeNormSetup(makeSumSpec(Collections.emptyList(), List.of("bar")),
- List.of(makeNormalizer("foo", List.of(115.0, 65.0, 55.0, 45.0, 15.0), makeSumSpec(List.of("x"), List.of("bar")))));
+ var setup = setup().eval(makeSumSpec(Collections.emptyList(), List.of("bar")))
+ .addNormalizer(makeNormalizer("foo", List.of(115.0, 65.0, 55.0, 45.0, 15.0), makeSumSpec(List.of("x"), List.of("bar")))).build();
var query = makeQuery(List.of(value("query(x)", 5)));
var factory = new HitFactory(List.of("bar"));
var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("bar", 10))),