aboutsummaryrefslogtreecommitdiffstats
path: root/linguistics/src/test/java/com/yahoo/language/sentencepiece
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-09-16 11:04:56 +0200
committerJon Bratseth <bratseth@gmail.com>2021-09-16 11:04:56 +0200
commita2afdafbffcbc09594fd629c65746ec253f180be (patch)
tree987804d4362c801c6c47cfe469ec6b97409de6ef /linguistics/src/test/java/com/yahoo/language/sentencepiece
parent381033510b992049d55cae9964d942b4b47eb9df (diff)
Make SentencePieceEncoder configurable
Diffstat (limited to 'linguistics/src/test/java/com/yahoo/language/sentencepiece')
-rw-r--r--linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java59
-rw-r--r--linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java33
-rw-r--r--linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java40
3 files changed, 102 insertions, 30 deletions
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
new file mode 100644
index 00000000000..edbbe21ec53
--- /dev/null
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
@@ -0,0 +1,59 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.config.FileReference;
+import com.yahoo.language.Language;
+import org.junit.Test;
+
+/**
+ * @author bratseth
+ */
+public class SentencePieceConfigurationTest {
+
+ @Test
+ public void testEnglishTokenization() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence");
+ tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo");
+ }
+
+ @Test
+ public void testNoCollapse() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ b.collapseUnknowns(false);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
+ }
+
+ @Test
+ public void testHighestScore() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ b.scoring(SentencePieceConfig.Scoring.highestScore);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented("hello", "▁h", "el", "lo");
+ }
+
+ @Test
+ public void testMultiLanguageTokenization() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b);
+ addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
+ tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
+ tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
+ }
+
+ private void addModel(String language, String file, SentencePieceConfig.Builder b) {
+ var mb = new SentencePieceConfig.Model.Builder();
+ mb.language(language);
+ mb.path(new FileReference(file));
+ b.model(mb);
+ }
+
+}
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index 7d0c1c5c78e..f86bc2f716b 100644
--- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -6,10 +6,6 @@ import com.yahoo.language.Language;
import org.junit.Test;
import java.io.File;
-import java.io.IOException;
-import java.nio.file.Path;
-
-import static org.junit.Assert.assertArrayEquals;
/**
* @author bratseth
@@ -61,37 +57,14 @@ public class SentencePieceTest {
}
@Test
- public void testJapaneseTokenization() throws IOException {
+ public void testMultiLanguageTokenization() {
SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder();
builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath());
builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
var tester = new SentencePieceTester(builder);
tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
- }
-
- private static class SentencePieceTester {
-
- private final SentencePieceEncoder encoder;
-
- public SentencePieceTester(Path model) {
- this(new SentencePieceEncoder.Builder().addDefaultModel(model));
- }
-
- public SentencePieceTester(SentencePieceEncoder.Builder builder) {
- encoder = builder.build();
- }
-
- private void assertEncoded(String input, Integer ... expectedCodes) {
- assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
- }
-
- private void assertSegmented(String input, String ... expectedSegments) {
- assertSegmented(Language.UNKNOWN, input, expectedSegments);
- }
- private void assertSegmented(Language language, String input, String ... expectedSegments) {
- assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
- }
-
+ tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
+ tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
}
}
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
new file mode 100644
index 00000000000..dee9be5aa7e
--- /dev/null
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -0,0 +1,40 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+//
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.language.Language;
+
+import java.nio.file.Path;
+
+import static org.junit.Assert.assertArrayEquals;
+
+class SentencePieceTester {
+
+ private final SentencePieceEncoder encoder;
+
+ public SentencePieceTester(Path model) {
+ this(new SentencePieceEncoder.Builder().addDefaultModel(model));
+ }
+
+ public SentencePieceTester(SentencePieceEncoder.Builder builder) {
+ this(builder.build());
+ }
+
+ public SentencePieceTester(SentencePieceEncoder encoder) {
+ this.encoder = encoder;
+ }
+
+ public void assertEncoded(String input, Integer... expectedCodes) {
+ assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
+ }
+
+ public void assertSegmented(String input, String... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+
+ public void assertSegmented(Language language, String input, String... expectedSegments) {
+ assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
+ }
+
+}