From a37a53cec5aa818ea005e4cefb3349eb9fe0b769 Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Wed, 18 Oct 2023 14:01:51 +0200 Subject: support default values for query features --- .../yahoo/search/ranking/GlobalPhaseRanker.java | 13 +++--- .../com/yahoo/search/ranking/GlobalPhaseSetup.java | 16 +++---- .../com/yahoo/search/ranking/PreparedInput.java | 13 +++--- .../ranking/GlobalPhaseRerankHitsImplTest.java | 50 +++++++++++++--------- 4 files changed, 50 insertions(+), 42 deletions(-) diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java index 829d0c268e5..91acc883803 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java @@ -13,10 +13,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.data.access.helpers.MatchFeatureData; import com.yahoo.data.access.helpers.MatchFeatureFilter; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.function.Supplier; import java.util.logging.Logger; @@ -52,12 +49,12 @@ public class GlobalPhaseRanker { static void rerankHitsImpl(GlobalPhaseSetup setup, Query query, Result result) { var mainSpec = setup.globalPhaseEvalSpec; - var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), query); + var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), setup.defaultValues, query); int rerankCount = resolveRerankCount(setup, query); var normalizers = new ArrayList(); for (var nSetup : setup.normalizers) { var normSpec = nSetup.inputEvalSpec(); - var normEvalSrc = withQueryPrep(normSpec.evalSource(), normSpec.fromQuery(), query); + var normEvalSrc = withQueryPrep(normSpec.evalSource(), normSpec.fromQuery(), setup.defaultValues, query); normalizers.add(new NormalizerContext(nSetup.name(), nSetup.supplier().get(), normEvalSrc, normSpec.fromMF())); } var rescorer = new HitRescorer(mainSrc, mainSpec.fromMF(), normalizers); @@ -73,8 +70,8 @@ public class GlobalPhaseRanker { } } - static Supplier withQueryPrep(Supplier evalSource, List queryFeatures, Query query) { - var prepared = PreparedInput.findFromQuery(query, queryFeatures); + static Supplier withQueryPrep(Supplier evalSource, List queryFeatures, Map defaultValues, Query query) { + var prepared = PreparedInput.findFromQuery(query, queryFeatures, defaultValues); Supplier supplier = () -> { var evaluator = evalSource.get(); for (var entry : prepared) { diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java index 31a676e4c8e..e5cd09d3a18 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java @@ -3,15 +3,10 @@ package com.yahoo.search.ranking; import ai.vespa.models.evaluation.FunctionEvaluator; +import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.Map; -import java.util.HashMap; +import java.util.*; import java.util.function.Supplier; class GlobalPhaseSetup { @@ -20,16 +15,19 @@ class GlobalPhaseSetup { final int rerankCount; final Collection matchFeaturesToHide; final List normalizers; + final Map defaultValues; GlobalPhaseSetup(FunEvalSpec globalPhaseEvalSpec, final int rerankCount, Collection matchFeaturesToHide, - List normalizers) + List normalizers, + Map defaultValues) { this.globalPhaseEvalSpec = globalPhaseEvalSpec; this.rerankCount = rerankCount; this.matchFeaturesToHide = matchFeaturesToHide; this.normalizers = normalizers; + this.defaultValues = defaultValues; } static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) { @@ -106,7 +104,7 @@ class GlobalPhaseSetup { } Supplier supplier = SimpleEvaluator.wrap(functionEvaluatorSource); var gfun = new FunEvalSpec(supplier, fromQuery, fromMF); - return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers); + return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers, Collections.emptyMap()); } return null; } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java index 5491724cc08..914635fef59 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java @@ -13,16 +13,13 @@ import com.yahoo.tensor.Tensor; import com.yahoo.data.access.helpers.MatchFeatureData; import com.yahoo.data.access.helpers.MatchFeatureFilter; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.function.Supplier; import java.util.logging.Logger; record PreparedInput(String name, Tensor value) { - static List findFromQuery(Query query, Collection queryFeatures) { + static List findFromQuery(Query query, Collection queryFeatures, Map defaultValues) { List result = new ArrayList<>(); var ranking = query.getRanking(); var rankFeatures = ranking.getFeatures(); @@ -35,6 +32,12 @@ record PreparedInput(String name, Tensor value) { // searchers are recommended to place query features here: feature = rankFeatures.getTensor(needed); } + if (feature.isEmpty()) { + var t = defaultValues.get(needed); + if (t != null) { + feature = Optional.of(t); + } + } if (feature.isEmpty()) { throw new IllegalArgumentException("missing query feature: " + queryFeatureName); } diff --git a/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java index ce9ac377908..f55130c0c93 100644 --- a/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java +++ b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java @@ -60,17 +60,20 @@ public class GlobalPhaseRerankHitsImplTest { static NormalizerSetup makeNormalizer(String name, List expected, FunEvalSpec evalSpec) { return new NormalizerSetup(name, () -> new ExpectingNormalizer(expected), evalSpec); } - static GlobalPhaseSetup makeFullSetup(FunEvalSpec mainSpec, int rerankCount, - List hiddenMF, List normalizers) - { - return new GlobalPhaseSetup(mainSpec, rerankCount, hiddenMF, normalizers); - } - static GlobalPhaseSetup makeSimpleSetup(FunEvalSpec mainSpec, int rerankCount) { - return makeFullSetup(mainSpec, rerankCount, Collections.emptyList(), Collections.emptyList()); - } - static GlobalPhaseSetup makeNormSetup(FunEvalSpec mainSpec, List normalizers) { - return makeFullSetup(mainSpec, 100, Collections.emptyList(), normalizers); - } + static class SetupBuilder { + FunEvalSpec mainSpec = makeConstSpec(0.0); + int rerankCount = 100; + List hiddenMF = new ArrayList<>(); + List normalizers = new ArrayList<>(); + Map defaultValues = new HashMap<>(); + SetupBuilder eval(FunEvalSpec spec) { mainSpec = spec; return this; } + SetupBuilder rerank(int value) { rerankCount = value; return this; } + SetupBuilder hide(String mf) { hiddenMF.add(mf); return this; } + SetupBuilder addNormalizer(NormalizerSetup normalizer) { normalizers.add(normalizer); return this; } + SetupBuilder addDefault(String name, Tensor value) { defaultValues.put(name, value); return this; } + GlobalPhaseSetup build() { return new GlobalPhaseSetup(mainSpec, rerankCount, hiddenMF, normalizers, defaultValues); } + } + static SetupBuilder setup() { return new SetupBuilder(); } static record NamedValue(String name, double value) {} NamedValue value(String name, double value) { return new NamedValue(name, value); @@ -167,7 +170,7 @@ public class GlobalPhaseRerankHitsImplTest { } } @Test void partialRerankWithRescaling() { - var setup = makeSimpleSetup(makeConstSpec(3.0), 2); + var setup = setup().rerank(2).eval(makeConstSpec(3.0)).build(); var query = makeQuery(Collections.emptyList()); var result = makeResult(query, List.of(hit("a", 3), hit("b", 4), hit("c", 5), hit("d", 6))); var expect = Expect.make(List.of(hit("a", 1), hit("b", 2), hit("c", 3), hit("d", 3))); @@ -175,8 +178,7 @@ public class GlobalPhaseRerankHitsImplTest { expect.verifyScores(result); } @Test void matchFeaturesCanBePartiallyHidden() { - var setup = makeFullSetup(makeSumSpec(Collections.emptyList(), List.of("public_value", "private_value")), 2, - List.of("private_value"), Collections.emptyList()); + var setup = setup().eval(makeSumSpec(Collections.emptyList(), List.of("public_value", "private_value"))).hide("private_value").build(); var query = makeQuery(Collections.emptyList()); var factory = new HitFactory(List.of("public_value", "private_value")); var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("public_value", 2), value("private_value", 3))), @@ -188,8 +190,7 @@ public class GlobalPhaseRerankHitsImplTest { verifyDoesNotHaveMF(result, "private_value"); } @Test void matchFeaturesCanBeRemoved() { - var setup = makeFullSetup(makeSumSpec(Collections.emptyList(), List.of("private_value")), 2, - List.of("private_value"), Collections.emptyList()); + var setup = setup().eval(makeSumSpec(Collections.emptyList(), List.of("private_value"))).hide("private_value").build(); var query = makeQuery(Collections.emptyList()); var factory = new HitFactory(List.of("private_value")); var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("private_value", 3))), @@ -200,7 +201,7 @@ public class GlobalPhaseRerankHitsImplTest { verifyDoesNotHaveMatchFeaturesField(result); } @Test void queryFeaturesCanBeUsed() { - var setup = makeSimpleSetup(makeSumSpec(List.of("foo"), List.of("bar")), 2); + var setup = setup().eval(makeSumSpec(List.of("foo"), List.of("bar"))).build(); var query = makeQuery(List.of(value("query(foo)", 7))); var factory = new HitFactory(List.of("bar")); var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("bar", 2))), @@ -211,7 +212,7 @@ public class GlobalPhaseRerankHitsImplTest { verifyHasMF(result, "bar"); } @Test void queryFeaturesCanBeUsedWhenPrepared() { - var setup = makeSimpleSetup(makeSumSpec(List.of("foo"), List.of("bar")), 2); + var setup = setup().eval(makeSumSpec(List.of("foo"), List.of("bar"))).build(); var query = makeQueryWithPrepare(List.of(value("query(foo)", 7))); var factory = new HitFactory(List.of("bar")); var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("bar", 2))), @@ -221,9 +222,18 @@ public class GlobalPhaseRerankHitsImplTest { expect.verifyScores(result); verifyHasMF(result, "bar"); } + @Test void queryFeaturesCanBeDefaultValues() { + var setup = setup().eval(makeSumSpec(List.of("foo", "bar"), Collections.emptyList())) + .addDefault("query(bar)", Tensor.from(5.0)).build(); + var query = makeQuery(List.of(value("query(foo)", 7))); + var result = makeResult(query, List.of(hit("a", 1))); + var expect = Expect.make(List.of(hit("a", 12))); + GlobalPhaseRanker.rerankHitsImpl(setup, query, result); + expect.verifyScores(result); + } @Test void withNormalizer() { - var setup = makeNormSetup(makeSumSpec(Collections.emptyList(), List.of("bar")), - List.of(makeNormalizer("foo", List.of(115.0, 65.0, 55.0, 45.0, 15.0), makeSumSpec(List.of("x"), List.of("bar"))))); + var setup = setup().eval(makeSumSpec(Collections.emptyList(), List.of("bar"))) + .addNormalizer(makeNormalizer("foo", List.of(115.0, 65.0, 55.0, 45.0, 15.0), makeSumSpec(List.of("x"), List.of("bar")))).build(); var query = makeQuery(List.of(value("query(x)", 5))); var factory = new HitFactory(List.of("bar")); var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("bar", 10))), -- cgit v1.2.3