diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-09-16 11:04:56 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-09-16 11:04:56 +0200 |
commit | a2afdafbffcbc09594fd629c65746ec253f180be (patch) | |
tree | 987804d4362c801c6c47cfe469ec6b97409de6ef /linguistics/src/test/java/com/yahoo/language/sentencepiece | |
parent | 381033510b992049d55cae9964d942b4b47eb9df (diff) |
Make SentencePieceEncoder configurable
Diffstat (limited to 'linguistics/src/test/java/com/yahoo/language/sentencepiece')
3 files changed, 102 insertions, 30 deletions
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java new file mode 100644 index 00000000000..edbbe21ec53 --- /dev/null +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java @@ -0,0 +1,59 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.sentencepiece; + +import com.yahoo.config.FileReference; +import com.yahoo.language.Language; +import org.junit.Test; + +/** + * @author bratseth + */ +public class SentencePieceConfigurationTest { + + @Test + public void testEnglishTokenization() { + var b = new SentencePieceConfig.Builder(); + addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); + tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); + } + + @Test + public void testNoCollapse() { + var b = new SentencePieceConfig.Builder(); + addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + b.collapseUnknowns(false); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); + } + + @Test + public void testHighestScore() { + var b = new SentencePieceConfig.Builder(); + addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + b.scoring(SentencePieceConfig.Scoring.highestScore); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented("hello", "▁h", "el", "lo"); + } + + @Test + public void testMultiLanguageTokenization() { + var b = new SentencePieceConfig.Builder(); + addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b); + addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); + tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); + tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); + } + + private void addModel(String language, String file, SentencePieceConfig.Builder b) { + var mb = new SentencePieceConfig.Model.Builder(); + mb.language(language); + mb.path(new FileReference(file)); + b.model(mb); + } + +} diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index 7d0c1c5c78e..f86bc2f716b 100644 --- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -6,10 +6,6 @@ import com.yahoo.language.Language; import org.junit.Test; import java.io.File; -import java.io.IOException; -import java.nio.file.Path; - -import static org.junit.Assert.assertArrayEquals; /** * @author bratseth @@ -61,37 +57,14 @@ public class SentencePieceTest { } @Test - public void testJapaneseTokenization() throws IOException { + public void testMultiLanguageTokenization() { SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder(); builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); var tester = new SentencePieceTester(builder); tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); - } - - private static class SentencePieceTester { - - private final SentencePieceEncoder encoder; - - public SentencePieceTester(Path model) { - this(new SentencePieceEncoder.Builder().addDefaultModel(model)); - } - - public SentencePieceTester(SentencePieceEncoder.Builder builder) { - encoder = builder.build(); - } - - private void assertEncoded(String input, Integer ... expectedCodes) { - assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray()); - } - - private void assertSegmented(String input, String ... expectedSegments) { - assertSegmented(Language.UNKNOWN, input, expectedSegments); - } - private void assertSegmented(Language language, String input, String ... expectedSegments) { - assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray()); - } - + tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); + tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); } } diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java new file mode 100644 index 00000000000..dee9be5aa7e --- /dev/null +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java @@ -0,0 +1,40 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// + +package com.yahoo.language.sentencepiece; + +import com.yahoo.language.Language; + +import java.nio.file.Path; + +import static org.junit.Assert.assertArrayEquals; + +class SentencePieceTester { + + private final SentencePieceEncoder encoder; + + public SentencePieceTester(Path model) { + this(new SentencePieceEncoder.Builder().addDefaultModel(model)); + } + + public SentencePieceTester(SentencePieceEncoder.Builder builder) { + this(builder.build()); + } + + public SentencePieceTester(SentencePieceEncoder encoder) { + this.encoder = encoder; + } + + public void assertEncoded(String input, Integer... expectedCodes) { + assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray()); + } + + public void assertSegmented(String input, String... expectedSegments) { + assertSegmented(Language.UNKNOWN, input, expectedSegments); + } + + public void assertSegmented(Language language, String input, String... expectedSegments) { + assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray()); + } + +} |