summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorMariusArhaug <mariusarhaug@hotmail.com>2024-04-10 14:33:36 +0200
committerMariusArhaug <mariusarhaug@hotmail.com>2024-04-10 14:33:36 +0200
commit9125b714f995bcf8d734cf155a7f17d60d43fdca (patch)
tree5628db7798de009462765d33fa4b7e91341658a5 /container-search
parent91cc80f7accdc9a5456c4fcb1c03002552586aa5 (diff)
add tests
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/main/java/com/yahoo/search/significance/SignificanceSearcher.java19
-rw-r--r--container-search/src/test/java/com/yahoo/search/significance/model/en.json14
-rw-r--r--container-search/src/test/java/com/yahoo/search/significance/test/SignificanceSearcherTest.java113
3 files changed, 138 insertions, 8 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/significance/SignificanceSearcher.java b/container-search/src/main/java/com/yahoo/search/significance/SignificanceSearcher.java
index f33d1468334..7403d57b71f 100644
--- a/container-search/src/main/java/com/yahoo/search/significance/SignificanceSearcher.java
+++ b/container-search/src/main/java/com/yahoo/search/significance/SignificanceSearcher.java
@@ -1,7 +1,9 @@
package com.yahoo.search.significance;
import com.yahoo.component.annotation.Inject;
+import com.yahoo.component.chain.dependencies.Before;
import com.yahoo.component.chain.dependencies.Provides;
+import com.yahoo.language.Language;
import com.yahoo.language.significance.SignificanceModel;
import com.yahoo.language.significance.SignificanceModelRegistry;
import com.yahoo.prelude.query.CompositeItem;
@@ -13,7 +15,10 @@ import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.searchchain.Execution;
+import static com.yahoo.prelude.querytransform.StemmingSearcher.STEMMING;
+
@Provides(SignificanceSearcher.SIGNIFICANCE)
+@Before(STEMMING)
public class SignificanceSearcher extends Searcher {
public final static String SIGNIFICANCE = "Significance";
@@ -29,33 +34,31 @@ public class SignificanceSearcher extends Searcher {
public Result search(Query query, Execution execution) {
if (significanceModelRegistry == null) return execution.search(query);
-
- setIDF(query.getModel().getQueryTree().getRoot());
+ Language language = query.getModel().getParsingLanguage();
+ setIDF(query.getModel().getQueryTree().getRoot(), significanceModelRegistry.getModel(language));
return execution.search(query);
}
- private void setIDF(Item root) {
+ private void setIDF(Item root, SignificanceModel significanceModel) {
if (root == null || root instanceof NullItem) return;
if (root instanceof WordItem) {
- SignificanceModel significanceModel = significanceModelRegistry.getModel(root.getLanguage());
-
var documentFrequency = significanceModel.documentFrequency(((WordItem) root).getWord());
- long nq_i = documentFrequency.frequency();
long N = documentFrequency.corpusSize();
+ long nq_i = documentFrequency.frequency();
double idf = calculateIDF(N, nq_i);
((WordItem) root).setSignificance(idf);
} else if (root instanceof CompositeItem) {
for (int i = 0; i < ((CompositeItem) root).getItemCount(); i++) {
- setIDF(((CompositeItem) root).getItem(i));
+ setIDF(((CompositeItem) root).getItem(i), significanceModel);
}
}
}
- private static double calculateIDF(long N, long nq_i) {
+ public static double calculateIDF(long N, long nq_i) {
return Math.log(1 + (N - nq_i + 0.5) / (nq_i + 0.5));
}
}
diff --git a/container-search/src/test/java/com/yahoo/search/significance/model/en.json b/container-search/src/test/java/com/yahoo/search/significance/model/en.json
new file mode 100644
index 00000000000..50bae5e3451
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/significance/model/en.json
@@ -0,0 +1,14 @@
+{
+ "version" : "1.0",
+ "id" : "test::1",
+ "description" : "desc",
+ "corpus-size" : 10,
+ "language" : "en",
+ "word-count" : 4,
+ "frequencies" : {
+ "usa" : 2,
+ "hello": 3,
+ "world": 5,
+ "test": 2
+ }
+}
diff --git a/container-search/src/test/java/com/yahoo/search/significance/test/SignificanceSearcherTest.java b/container-search/src/test/java/com/yahoo/search/significance/test/SignificanceSearcherTest.java
new file mode 100644
index 00000000000..389236af31b
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/significance/test/SignificanceSearcherTest.java
@@ -0,0 +1,113 @@
+package com.yahoo.search.significance.test;
+
+import com.yahoo.component.chain.Chain;
+import com.yahoo.language.Language;
+import com.yahoo.language.significance.SignificanceModel;
+import com.yahoo.language.significance.SignificanceModelRegistry;
+import com.yahoo.language.significance.impl.DefaultSignificanceModelRegistry;
+import com.yahoo.prelude.query.AndItem;
+import com.yahoo.prelude.query.WordItem;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.search.significance.SignificanceSearcher;
+import org.junit.jupiter.api.Test;
+
+import java.nio.file.Path;
+import java.util.HashMap;
+
+import static com.yahoo.test.JunitCompat.assertEquals;
+
+public class SignificanceSearcherTest {
+ SignificanceModelRegistry significanceModelRegistry;
+ SignificanceSearcher searcher;
+
+ public SignificanceSearcherTest() {
+ HashMap<Language, Path> map = new HashMap<>();
+ map.put(Language.ENGLISH, Path.of("src/test/java/com/yahoo/search/significance/model/en.json"));
+
+ significanceModelRegistry = new DefaultSignificanceModelRegistry(map);
+ // TODO change to mock
+ searcher = new SignificanceSearcher(significanceModelRegistry);
+ }
+
+ private Execution createExecution() {
+ return new Execution(new Chain<>(searcher), Execution.Context.createContextStub());
+ }
+
+ @Test
+ void testSimpleSignificanceValue() {
+
+ Query q = new Query();
+ AndItem root = new AndItem();
+ WordItem tmp;
+ tmp = new WordItem("Hello", true);
+ root.addItem(tmp);
+ tmp = new WordItem("world", true);
+ root.addItem(tmp);
+
+ q.getModel().getQueryTree().setRoot(root);
+
+ SignificanceModel model = significanceModelRegistry.getModel(Language.ENGLISH);
+ var helloFrequency = model.documentFrequency("Hello");
+ var helloSignificanceValue = SignificanceSearcher.calculateIDF(helloFrequency.corpusSize(), helloFrequency.frequency());
+
+ var worldFrequency = model.documentFrequency("world");
+ var worldSignificanceValue = SignificanceSearcher.calculateIDF(worldFrequency.corpusSize(), worldFrequency.frequency());
+
+ Result r = createExecution().search(q);
+
+ root = (AndItem) r.getQuery().getModel().getQueryTree().getRoot();
+ WordItem w0 = (WordItem) root.getItem(0);
+ WordItem w1 = (WordItem) root.getItem(1);
+
+ assertEquals(helloSignificanceValue, w0.getSignificance());
+ assertEquals(worldSignificanceValue, w1.getSignificance());
+
+ }
+
+ @Test
+ void testRecursiveSignificanceValues() {
+ Query q = new Query();
+ AndItem root = new AndItem();
+ WordItem child1 = new WordItem("hello", true);
+
+ AndItem child2 = new AndItem();
+ WordItem child2_1 = new WordItem("test", true);
+
+ AndItem child3 = new AndItem();
+ AndItem child3_1 = new AndItem();
+ WordItem child3_1_1 = new WordItem("usa", true);
+
+ root.addItem(child1);
+ root.addItem(child2);
+ root.addItem(child3);
+
+ child2.addItem(child2_1);
+ child3.addItem(child3_1);
+ child3_1.addItem(child3_1_1);
+
+ q.getModel().getQueryTree().setRoot(root);
+
+ SignificanceModel model = significanceModelRegistry.getModel(Language.ENGLISH);
+ var helloFrequency = model.documentFrequency("hello");
+ var helloSignificanceValue = SignificanceSearcher.calculateIDF(helloFrequency.corpusSize(), helloFrequency.frequency());
+
+ var testFrequency = model.documentFrequency("test");
+ var testSignificanceValue = SignificanceSearcher.calculateIDF(testFrequency.corpusSize(), testFrequency.frequency());
+
+
+
+ Result r = createExecution().search(q);
+
+ root = (AndItem) r.getQuery().getModel().getQueryTree().getRoot();
+ WordItem w0 = (WordItem) root.getItem(0);
+ WordItem w1 = (WordItem) ((AndItem) root.getItem(1)).getItem(0);
+ WordItem w3 = (WordItem) ((AndItem) ((AndItem) root.getItem(2)).getItem(0)).getItem(0);
+
+ assertEquals(helloSignificanceValue, w0.getSignificance());
+ assertEquals(testSignificanceValue, w1.getSignificance());
+ assertEquals(SignificanceSearcher.calculateIDF(10, 2), w3.getSignificance());
+
+ }
+}