From 484423b747694ae254dfd128fe51c8047e3da349 Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Mon, 16 Oct 2023 16:37:00 +0200 Subject: test global phase reranking with dummy evaluation --- .../yahoo/search/ranking/GlobalPhaseRanker.java | 17 +- .../com/yahoo/search/ranking/PreparedInput.java | 4 +- .../ranking/GlobalPhaseRerankHitsImplTest.java | 238 +++++++++++++++++++++ 3 files changed, 252 insertions(+), 7 deletions(-) create mode 100644 container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java index bb5a991c304..829d0c268e5 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java @@ -50,9 +50,7 @@ public class GlobalPhaseRanker { return Optional.empty(); } - public void rerankHits(Query query, Result result, String schema) { - var setup = globalPhaseSetupFor(query, schema).orElse(null); - if (setup == null) return; + static void rerankHitsImpl(GlobalPhaseSetup setup, Query query, Result result) { var mainSpec = setup.globalPhaseEvalSpec; var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), query); int rerankCount = resolveRerankCount(setup, query); @@ -68,6 +66,13 @@ public class GlobalPhaseRanker { hideImplicitMatchFeatures(result, setup.matchFeaturesToHide); } + public void rerankHits(Query query, Result result, String schema) { + var setup = globalPhaseSetupFor(query, schema); + if (setup.isPresent()) { + rerankHitsImpl(setup.get(), query, result); + } + } + static Supplier withQueryPrep(Supplier evalSource, List queryFeatures, Query query) { var prepared = PreparedInput.findFromQuery(query, queryFeatures); Supplier supplier = () -> { @@ -80,7 +85,7 @@ public class GlobalPhaseRanker { return supplier; } - private void hideImplicitMatchFeatures(Result result, Collection namesToHide) { + private static void hideImplicitMatchFeatures(Result result, Collection namesToHide) { if (namesToHide.size() == 0) return; var filter = new MatchFeatureFilter(namesToHide); for (var iterator = result.hits().deepIterator(); iterator.hasNext();) { @@ -94,7 +99,7 @@ public class GlobalPhaseRanker { if (newValue.fieldCount() == 0) { hit.removeField("matchfeatures"); } else { - hit.setField("matchfeatures", newValue); + hit.setField("matchfeatures", new FeatureData(newValue)); } } } @@ -106,7 +111,7 @@ public class GlobalPhaseRanker { .flatMap(evaluator -> evaluator.getGlobalPhaseSetup(query.getRanking().getProfile())); } - private int resolveRerankCount(GlobalPhaseSetup setup, Query query) { + private static int resolveRerankCount(GlobalPhaseSetup setup, Query query) { if (setup == null) { // there is no global-phase at all (ignore override) return 0; diff --git a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java index 5ab2d7160f9..346acccd916 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java @@ -30,7 +30,7 @@ record PreparedInput(String name, Tensor value) { for (String queryFeatureName : queryFeatures) { String needed = "query(" + queryFeatureName + ")"; // searchers are recommended to place query features here: - var feature = rankFeatures.getTensor(queryFeatureName); + var feature = rankFeatures.getTensor(needed); if (feature.isPresent()) { result.add(new PreparedInput(needed, feature.get())); } else { @@ -38,6 +38,8 @@ record PreparedInput(String name, Tensor value) { var objList = rankProps.get(queryFeatureName); if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) { result.add(new PreparedInput(needed, t)); + } else if (objList != null && objList.size() == 1 && objList.get(0) instanceof Double d) { + result.add(new PreparedInput(needed, Tensor.from(d))); } else { throw new IllegalArgumentException("missing query feature: " + queryFeatureName); } diff --git a/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java new file mode 100644 index 00000000000..ce9ac377908 --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseRerankHitsImplTest.java @@ -0,0 +1,238 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.data.access.Inspectable; +import com.yahoo.data.access.Type; +import com.yahoo.data.access.helpers.MatchFeatureData; +import com.yahoo.data.access.simple.Value; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; +import com.yahoo.tensor.Tensor; +import org.junit.jupiter.api.Test; + +import java.util.*; +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.*; + +public class GlobalPhaseRerankHitsImplTest { + static class EvalSum implements Evaluator { + double baseValue; + List values = new ArrayList<>(); + EvalSum(double baseValue) { this.baseValue = baseValue; } + @Override public Evaluator bind(String name, Tensor value) { + values.add(value); + return this; + } + @Override public double evaluateScore() { + double result = baseValue; + for (var value: values) { + result += value.asDouble(); + } + return result; + } + } + static FunEvalSpec makeConstSpec(double constValue) { + return new FunEvalSpec(() -> new EvalSum(constValue), Collections.emptyList(), Collections.emptyList()); + } + static FunEvalSpec makeSumSpec(List fromQuery, List fromMF) { + return new FunEvalSpec(() -> new EvalSum(0.0), fromQuery, fromMF); + } + static class ExpectingNormalizer extends Normalizer { + List expected; + ExpectingNormalizer(List expected) { + super(100); + this.expected = expected; + } + @Override void normalize() { + double rank = 1; + assertEquals(size, expected.size()); + for (int i = 0; i < size; i++) { + assertEquals(data[i], expected.get(i)); + data[i] = rank; + rank += 1; + } + } + @Override String normalizing() { return "expecting"; } + } + static NormalizerSetup makeNormalizer(String name, List expected, FunEvalSpec evalSpec) { + return new NormalizerSetup(name, () -> new ExpectingNormalizer(expected), evalSpec); + } + static GlobalPhaseSetup makeFullSetup(FunEvalSpec mainSpec, int rerankCount, + List hiddenMF, List normalizers) + { + return new GlobalPhaseSetup(mainSpec, rerankCount, hiddenMF, normalizers); + } + static GlobalPhaseSetup makeSimpleSetup(FunEvalSpec mainSpec, int rerankCount) { + return makeFullSetup(mainSpec, rerankCount, Collections.emptyList(), Collections.emptyList()); + } + static GlobalPhaseSetup makeNormSetup(FunEvalSpec mainSpec, List normalizers) { + return makeFullSetup(mainSpec, 100, Collections.emptyList(), normalizers); + } + static record NamedValue(String name, double value) {} + NamedValue value(String name, double value) { + return new NamedValue(name, value); + } + Query makeQuery(List inQuery, boolean withPrepare) { + var query = new Query(); + for (var v: inQuery) { + query.getRanking().getFeatures().put(v.name, v.value); + } + if (withPrepare) { + query.getRanking().prepare(); + } + return query; + } + Query makeQuery(List inQuery) { return makeQuery(inQuery, false); } + Query makeQueryWithPrepare(List inQuery) { return makeQuery(inQuery, true); } + + static Hit makeHit(String id, double score, FeatureData mf) { + Hit hit = new Hit(id, score); + hit.setField("matchfeatures", mf); + return hit; + } + static Hit hit(String id, double score) { + return makeHit(id, score, FeatureData.empty()); + } + static class HitFactory { + MatchFeatureData mfData; + Map map = new HashMap<>(); + HitFactory(List mfNames) { + int i = 0; + for (var name: mfNames) { + map.put(name, i++); + } + mfData = new MatchFeatureData(mfNames); + } + Hit create(String id, double score, List inMF) { + var mf = mfData.addHit(); + for (var v: inMF) { + var idx = map.get(v.name); + assertNotNull(idx); + mf.set(idx, v.value); + } + return makeHit(id, score, new FeatureData(mf)); + } + } + Result makeResult(Query query, List hits) { + var result = new Result(query); + result.hits().addAll(hits); + return result; + } + static class Expect { + Map map = new HashMap<>(); + static Expect make(List hits) { + var result = new Expect(); + for (var hit : hits) { + result.map.put(hit.getId().stringValue(), hit.getRelevance().getScore()); + } + return result; + } + void verifyScores(Result actual) { + double prev = Double.MAX_VALUE; + assertEquals(actual.hits().size(), map.size()); + for (var hit : actual.hits()) { + var name = hit.getId().stringValue(); + var score = map.get(name); + assertNotNull(score, name); + assertEquals(score.doubleValue(), hit.getRelevance().getScore(), name); + assertTrue(score <= prev); + prev = score; + } + } + } + void verifyHasMF(Result result, String name) { + for (var hit: result.hits()) { + if (hit.getField("matchfeatures") instanceof FeatureData mf) { + assertNotNull(mf.getTensor(name)); + } else { + fail("matchfeatures are missing"); + } + } + } + void verifyDoesNotHaveMF(Result result, String name) { + for (var hit: result.hits()) { + if (hit.getField("matchfeatures") instanceof FeatureData mf) { + assertNull(mf.getTensor(name)); + } else { + fail("matchfeatures are missing"); + } + } + } + void verifyDoesNotHaveMatchFeaturesField(Result result) { + for (var hit: result.hits()) { + assertNull(hit.getField("matchfeatures")); + } + } + @Test void partialRerankWithRescaling() { + var setup = makeSimpleSetup(makeConstSpec(3.0), 2); + var query = makeQuery(Collections.emptyList()); + var result = makeResult(query, List.of(hit("a", 3), hit("b", 4), hit("c", 5), hit("d", 6))); + var expect = Expect.make(List.of(hit("a", 1), hit("b", 2), hit("c", 3), hit("d", 3))); + GlobalPhaseRanker.rerankHitsImpl(setup, query, result); + expect.verifyScores(result); + } + @Test void matchFeaturesCanBePartiallyHidden() { + var setup = makeFullSetup(makeSumSpec(Collections.emptyList(), List.of("public_value", "private_value")), 2, + List.of("private_value"), Collections.emptyList()); + var query = makeQuery(Collections.emptyList()); + var factory = new HitFactory(List.of("public_value", "private_value")); + var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("public_value", 2), value("private_value", 3))), + factory.create("b", 2, List.of(value("public_value", 5), value("private_value", 7))))); + var expect = Expect.make(List.of(hit("a", 5), hit("b", 12))); + GlobalPhaseRanker.rerankHitsImpl(setup, query, result); + expect.verifyScores(result); + verifyHasMF(result, "public_value"); + verifyDoesNotHaveMF(result, "private_value"); + } + @Test void matchFeaturesCanBeRemoved() { + var setup = makeFullSetup(makeSumSpec(Collections.emptyList(), List.of("private_value")), 2, + List.of("private_value"), Collections.emptyList()); + var query = makeQuery(Collections.emptyList()); + var factory = new HitFactory(List.of("private_value")); + var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("private_value", 3))), + factory.create("b", 2, List.of(value("private_value", 7))))); + var expect = Expect.make(List.of(hit("a", 3), hit("b", 7))); + GlobalPhaseRanker.rerankHitsImpl(setup, query, result); + expect.verifyScores(result); + verifyDoesNotHaveMatchFeaturesField(result); + } + @Test void queryFeaturesCanBeUsed() { + var setup = makeSimpleSetup(makeSumSpec(List.of("foo"), List.of("bar")), 2); + var query = makeQuery(List.of(value("query(foo)", 7))); + var factory = new HitFactory(List.of("bar")); + var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("bar", 2))), + factory.create("b", 2, List.of(value("bar", 5))))); + var expect = Expect.make(List.of(hit("a", 9), hit("b", 12))); + GlobalPhaseRanker.rerankHitsImpl(setup, query, result); + expect.verifyScores(result); + verifyHasMF(result, "bar"); + } + @Test void queryFeaturesCanBeUsedWhenPrepared() { + var setup = makeSimpleSetup(makeSumSpec(List.of("foo"), List.of("bar")), 2); + var query = makeQueryWithPrepare(List.of(value("query(foo)", 7))); + var factory = new HitFactory(List.of("bar")); + var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("bar", 2))), + factory.create("b", 2, List.of(value("bar", 5))))); + var expect = Expect.make(List.of(hit("a", 9), hit("b", 12))); + GlobalPhaseRanker.rerankHitsImpl(setup, query, result); + expect.verifyScores(result); + verifyHasMF(result, "bar"); + } + @Test void withNormalizer() { + var setup = makeNormSetup(makeSumSpec(Collections.emptyList(), List.of("bar")), + List.of(makeNormalizer("foo", List.of(115.0, 65.0, 55.0, 45.0, 15.0), makeSumSpec(List.of("x"), List.of("bar"))))); + var query = makeQuery(List.of(value("query(x)", 5))); + var factory = new HitFactory(List.of("bar")); + var result = makeResult(query, List.of(factory.create("a", 1, List.of(value("bar", 10))), + factory.create("b", 2, List.of(value("bar", 40))), + factory.create("c", 3, List.of(value("bar", 50))), + factory.create("d", 4, List.of(value("bar", 60))), + factory.create("e", 5, List.of(value("bar", 110))))); + var expect = Expect.make(List.of(hit("a", 15), hit("b", 44), hit("c", 53), hit("d", 62), hit("e", 111))); + GlobalPhaseRanker.rerankHitsImpl(setup, query, result); + expect.verifyScores(result); + } +} -- cgit v1.2.3