diff options
author | MariusArhaug <mariusarhaug@hotmail.com> | 2024-04-10 14:33:36 +0200 |
---|---|---|
committer | MariusArhaug <mariusarhaug@hotmail.com> | 2024-04-10 14:33:36 +0200 |
commit | 9125b714f995bcf8d734cf155a7f17d60d43fdca (patch) | |
tree | 5628db7798de009462765d33fa4b7e91341658a5 /container-search | |
parent | 91cc80f7accdc9a5456c4fcb1c03002552586aa5 (diff) |
add tests
Diffstat (limited to 'container-search')
3 files changed, 138 insertions, 8 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/significance/SignificanceSearcher.java b/container-search/src/main/java/com/yahoo/search/significance/SignificanceSearcher.java index f33d1468334..7403d57b71f 100644 --- a/container-search/src/main/java/com/yahoo/search/significance/SignificanceSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/significance/SignificanceSearcher.java @@ -1,7 +1,9 @@ package com.yahoo.search.significance; import com.yahoo.component.annotation.Inject; +import com.yahoo.component.chain.dependencies.Before; import com.yahoo.component.chain.dependencies.Provides; +import com.yahoo.language.Language; import com.yahoo.language.significance.SignificanceModel; import com.yahoo.language.significance.SignificanceModelRegistry; import com.yahoo.prelude.query.CompositeItem; @@ -13,7 +15,10 @@ import com.yahoo.search.Result; import com.yahoo.search.Searcher; import com.yahoo.search.searchchain.Execution; +import static com.yahoo.prelude.querytransform.StemmingSearcher.STEMMING; + @Provides(SignificanceSearcher.SIGNIFICANCE) +@Before(STEMMING) public class SignificanceSearcher extends Searcher { public final static String SIGNIFICANCE = "Significance"; @@ -29,33 +34,31 @@ public class SignificanceSearcher extends Searcher { public Result search(Query query, Execution execution) { if (significanceModelRegistry == null) return execution.search(query); - - setIDF(query.getModel().getQueryTree().getRoot()); + Language language = query.getModel().getParsingLanguage(); + setIDF(query.getModel().getQueryTree().getRoot(), significanceModelRegistry.getModel(language)); return execution.search(query); } - private void setIDF(Item root) { + private void setIDF(Item root, SignificanceModel significanceModel) { if (root == null || root instanceof NullItem) return; if (root instanceof WordItem) { - SignificanceModel significanceModel = significanceModelRegistry.getModel(root.getLanguage()); - var documentFrequency = significanceModel.documentFrequency(((WordItem) root).getWord()); - long nq_i = documentFrequency.frequency(); long N = documentFrequency.corpusSize(); + long nq_i = documentFrequency.frequency(); double idf = calculateIDF(N, nq_i); ((WordItem) root).setSignificance(idf); } else if (root instanceof CompositeItem) { for (int i = 0; i < ((CompositeItem) root).getItemCount(); i++) { - setIDF(((CompositeItem) root).getItem(i)); + setIDF(((CompositeItem) root).getItem(i), significanceModel); } } } - private static double calculateIDF(long N, long nq_i) { + public static double calculateIDF(long N, long nq_i) { return Math.log(1 + (N - nq_i + 0.5) / (nq_i + 0.5)); } } diff --git a/container-search/src/test/java/com/yahoo/search/significance/model/en.json b/container-search/src/test/java/com/yahoo/search/significance/model/en.json new file mode 100644 index 00000000000..50bae5e3451 --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/significance/model/en.json @@ -0,0 +1,14 @@ +{ + "version" : "1.0", + "id" : "test::1", + "description" : "desc", + "corpus-size" : 10, + "language" : "en", + "word-count" : 4, + "frequencies" : { + "usa" : 2, + "hello": 3, + "world": 5, + "test": 2 + } +} diff --git a/container-search/src/test/java/com/yahoo/search/significance/test/SignificanceSearcherTest.java b/container-search/src/test/java/com/yahoo/search/significance/test/SignificanceSearcherTest.java new file mode 100644 index 00000000000..389236af31b --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/significance/test/SignificanceSearcherTest.java @@ -0,0 +1,113 @@ +package com.yahoo.search.significance.test; + +import com.yahoo.component.chain.Chain; +import com.yahoo.language.Language; +import com.yahoo.language.significance.SignificanceModel; +import com.yahoo.language.significance.SignificanceModelRegistry; +import com.yahoo.language.significance.impl.DefaultSignificanceModelRegistry; +import com.yahoo.prelude.query.AndItem; +import com.yahoo.prelude.query.WordItem; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.searchchain.Execution; +import com.yahoo.search.significance.SignificanceSearcher; +import org.junit.jupiter.api.Test; + +import java.nio.file.Path; +import java.util.HashMap; + +import static com.yahoo.test.JunitCompat.assertEquals; + +public class SignificanceSearcherTest { + SignificanceModelRegistry significanceModelRegistry; + SignificanceSearcher searcher; + + public SignificanceSearcherTest() { + HashMap<Language, Path> map = new HashMap<>(); + map.put(Language.ENGLISH, Path.of("src/test/java/com/yahoo/search/significance/model/en.json")); + + significanceModelRegistry = new DefaultSignificanceModelRegistry(map); + // TODO change to mock + searcher = new SignificanceSearcher(significanceModelRegistry); + } + + private Execution createExecution() { + return new Execution(new Chain<>(searcher), Execution.Context.createContextStub()); + } + + @Test + void testSimpleSignificanceValue() { + + Query q = new Query(); + AndItem root = new AndItem(); + WordItem tmp; + tmp = new WordItem("Hello", true); + root.addItem(tmp); + tmp = new WordItem("world", true); + root.addItem(tmp); + + q.getModel().getQueryTree().setRoot(root); + + SignificanceModel model = significanceModelRegistry.getModel(Language.ENGLISH); + var helloFrequency = model.documentFrequency("Hello"); + var helloSignificanceValue = SignificanceSearcher.calculateIDF(helloFrequency.corpusSize(), helloFrequency.frequency()); + + var worldFrequency = model.documentFrequency("world"); + var worldSignificanceValue = SignificanceSearcher.calculateIDF(worldFrequency.corpusSize(), worldFrequency.frequency()); + + Result r = createExecution().search(q); + + root = (AndItem) r.getQuery().getModel().getQueryTree().getRoot(); + WordItem w0 = (WordItem) root.getItem(0); + WordItem w1 = (WordItem) root.getItem(1); + + assertEquals(helloSignificanceValue, w0.getSignificance()); + assertEquals(worldSignificanceValue, w1.getSignificance()); + + } + + @Test + void testRecursiveSignificanceValues() { + Query q = new Query(); + AndItem root = new AndItem(); + WordItem child1 = new WordItem("hello", true); + + AndItem child2 = new AndItem(); + WordItem child2_1 = new WordItem("test", true); + + AndItem child3 = new AndItem(); + AndItem child3_1 = new AndItem(); + WordItem child3_1_1 = new WordItem("usa", true); + + root.addItem(child1); + root.addItem(child2); + root.addItem(child3); + + child2.addItem(child2_1); + child3.addItem(child3_1); + child3_1.addItem(child3_1_1); + + q.getModel().getQueryTree().setRoot(root); + + SignificanceModel model = significanceModelRegistry.getModel(Language.ENGLISH); + var helloFrequency = model.documentFrequency("hello"); + var helloSignificanceValue = SignificanceSearcher.calculateIDF(helloFrequency.corpusSize(), helloFrequency.frequency()); + + var testFrequency = model.documentFrequency("test"); + var testSignificanceValue = SignificanceSearcher.calculateIDF(testFrequency.corpusSize(), testFrequency.frequency()); + + + + Result r = createExecution().search(q); + + root = (AndItem) r.getQuery().getModel().getQueryTree().getRoot(); + WordItem w0 = (WordItem) root.getItem(0); + WordItem w1 = (WordItem) ((AndItem) root.getItem(1)).getItem(0); + WordItem w3 = (WordItem) ((AndItem) ((AndItem) root.getItem(2)).getItem(0)).getItem(0); + + assertEquals(helloSignificanceValue, w0.getSignificance()); + assertEquals(testSignificanceValue, w1.getSignificance()); + assertEquals(SignificanceSearcher.calculateIDF(10, 2), w3.getSignificance()); + + } +} |