aboutsummaryrefslogtreecommitdiffstats
path: root/linguistics/src/test/java/com/yahoo
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-09-13 19:29:36 +0200
committerJon Bratseth <bratseth@gmail.com>2021-09-13 19:29:36 +0200
commitbfae6ba4377ea51a7cc024c5e9772ed069feedf4 (patch)
treeff20b7491674e60a9851ee4bdb576677bb4fc486 /linguistics/src/test/java/com/yahoo
parentdd26caaa5cc05e9cacaa280a4bee5d9ddb56ecbc (diff)
Pure Java sentencepiece implementation
Diffstat (limited to 'linguistics/src/test/java/com/yahoo')
-rw-r--r--linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java78
1 files changed, 78 insertions, 0 deletions
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
new file mode 100644
index 00000000000..70361f55750
--- /dev/null
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -0,0 +1,78 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.language.Language;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * @author bratseth
+ */
+public class SentencePieceTest {
+
+ @Test
+ public void testEnglishTokenization() throws IOException {
+ var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ tester.assertSegmented("h", "▁h");
+ tester.assertSegmented("he", "▁he");
+ tester.assertSegmented("hel", "▁hel");
+ tester.assertSegmented("hello", "▁hel", "lo");
+ tester.assertSegmented("hei", "▁he", "i");
+ tester.assertSegmented("hei you", "▁he", "i", "▁you");
+ tester.assertSegmented("hei you", "▁he", "i", "▁you");
+ tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence");
+ tester.assertSegmented("hello world!", "▁hel", "lo", "▁world", "!");
+ tester.assertSegmented("Hello, world!", "▁", "H", "ello", ",", "▁world", "!");
+ tester.assertSegmented("HELLO, world!", "▁", "HELLO", ",", "▁world", "!");
+ tester.assertSegmented("KHJKJHHKJHHSH", "▁", "KHJKJHHKJHHSH");
+ tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo");
+ tester.assertSegmented(" hello ", "▁hel", "lo");
+ tester.assertSegmented(")(/&#()/\"\")", "▁)", "(", "/", "&", "#", "(", ")", "/", "\"", "\")");
+ tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁)", "(", "/", "&", "#", "(", "sm", "all", ")", "/", "\"", "in", "▁qu", "otes", "\")");
+ tester.assertSegmented("x.400AS", "▁x", ".", "4", "00", "AS");
+ tester.assertSegmented("A normal sentence. Yes one more.", "▁", "A", "▁normal", "▁sentence", ".", "▁", "Y", "es", "▁one", "▁more", ".");
+ tester.assertEncoded("hello, world!", 908, 1418, 9934, 501, 9960);
+ tester.assertEncoded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960);
+ }
+
+ @Test
+ public void testJapaneseTokenization() throws IOException {
+ SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder();
+ builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath());
+ builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ var tester = new SentencePieceTester(builder);
+ tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
+ }
+
+ private static class SentencePieceTester {
+
+ private final SentencePieceEncoder encoder;
+
+ public SentencePieceTester(Path model) {
+ this(new SentencePieceEncoder.Builder().addDefaultModel(model));
+ }
+
+ public SentencePieceTester(SentencePieceEncoder.Builder builder) {
+ encoder = builder.build();
+ }
+
+ private void assertEncoded(String input, Integer ... expectedCodes) {
+ assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
+ }
+
+ private void assertSegmented(String input, String ... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+ private void assertSegmented(Language language, String input, String ... expectedSegments) {
+ assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
+ }
+
+ }
+
+}