From bfae6ba4377ea51a7cc024c5e9772ed069feedf4 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 13 Sep 2021 19:29:36 +0200 Subject: Pure Java sentencepiece implementation --- .../language/sentencepiece/SentencePieceTest.java | 78 ++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java (limited to 'linguistics/src/test/java/com/yahoo') diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java new file mode 100644 index 00000000000..70361f55750 --- /dev/null +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -0,0 +1,78 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.sentencepiece; + +import com.yahoo.language.Language; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; + +import static org.junit.Assert.assertArrayEquals; + +/** + * @author bratseth + */ +public class SentencePieceTest { + + @Test + public void testEnglishTokenization() throws IOException { + var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + tester.assertSegmented("h", "▁h"); + tester.assertSegmented("he", "▁he"); + tester.assertSegmented("hel", "▁hel"); + tester.assertSegmented("hello", "▁hel", "lo"); + tester.assertSegmented("hei", "▁he", "i"); + tester.assertSegmented("hei you", "▁he", "i", "▁you"); + tester.assertSegmented("hei you", "▁he", "i", "▁you"); + tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); + tester.assertSegmented("hello world!", "▁hel", "lo", "▁world", "!"); + tester.assertSegmented("Hello, world!", "▁", "H", "ello", ",", "▁world", "!"); + tester.assertSegmented("HELLO, world!", "▁", "HELLO", ",", "▁world", "!"); + tester.assertSegmented("KHJKJHHKJHHSH", "▁", "KHJKJHHKJHHSH"); + tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); + tester.assertSegmented(" hello ", "▁hel", "lo"); + tester.assertSegmented(")(/&#()/\"\")", "▁)", "(", "/", "&", "#", "(", ")", "/", "\"", "\")"); + tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁)", "(", "/", "&", "#", "(", "sm", "all", ")", "/", "\"", "in", "▁qu", "otes", "\")"); + tester.assertSegmented("x.400AS", "▁x", ".", "4", "00", "AS"); + tester.assertSegmented("A normal sentence. Yes one more.", "▁", "A", "▁normal", "▁sentence", ".", "▁", "Y", "es", "▁one", "▁more", "."); + tester.assertEncoded("hello, world!", 908, 1418, 9934, 501, 9960); + tester.assertEncoded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960); + } + + @Test + public void testJapaneseTokenization() throws IOException { + SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder(); + builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); + builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + var tester = new SentencePieceTester(builder); + tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); + } + + private static class SentencePieceTester { + + private final SentencePieceEncoder encoder; + + public SentencePieceTester(Path model) { + this(new SentencePieceEncoder.Builder().addDefaultModel(model)); + } + + public SentencePieceTester(SentencePieceEncoder.Builder builder) { + encoder = builder.build(); + } + + private void assertEncoded(String input, Integer ... expectedCodes) { + assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray()); + } + + private void assertSegmented(String input, String ... expectedSegments) { + assertSegmented(Language.UNKNOWN, input, expectedSegments); + } + private void assertSegmented(Language language, String input, String ... expectedSegments) { + assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray()); + } + + } + +} -- cgit v1.2.3