diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-12-17 12:41:17 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-12-17 12:41:17 +0100 |
commit | 601b117281b74a578126a0f3effead55bc79c680 (patch) | |
tree | 29619184a8459763cc024b23e74960e6c9ec7f81 /linguistics-components/src/test/java/com/yahoo/language/sentencepiece | |
parent | 767cb63af0f530605180f5438767406e1db27520 (diff) |
BERT -> WordPiece, make subword prefix configurable
Diffstat (limited to 'linguistics-components/src/test/java/com/yahoo/language/sentencepiece')
3 files changed, 22 insertions, 75 deletions
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java index 1ed2271f774..19cb2267655 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java @@ -4,6 +4,7 @@ package com.yahoo.language.sentencepiece; import com.yahoo.config.FileReference; import com.yahoo.language.Language; +import com.yahoo.language.tools.EmbedderTester; import org.junit.Test; /** @@ -15,7 +16,7 @@ public class SentencePieceConfigurationTest { public void testEnglishTokenization() { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); } @@ -25,7 +26,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); b.collapseUnknowns(false); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); } @@ -34,7 +35,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); b.scoring(SentencePieceConfig.Scoring.highestScore); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("hello", "▁h", "el", "lo"); } @@ -43,7 +44,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b); addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index 8b3e2988c43..2fbafb23485 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -3,6 +3,7 @@ package com.yahoo.language.sentencepiece; import com.yahoo.language.Language; +import com.yahoo.language.tools.EmbedderTester; import org.junit.Test; import java.io.File; @@ -13,8 +14,8 @@ import java.io.File; public class SentencePieceTest { @Test - public void testEnglishTokenization() { - var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + public void testEnglishSegmenting() { + var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build()); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); tester.assertSegmented("hel", "▁hel"); @@ -36,33 +37,28 @@ public class SentencePieceTest { } @Test - public void testIntegerListEncoding() { - var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - tester.assertEmbedded("hello, world!", 908, 1418, 9934, 501, 9960); - tester.assertEmbedded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960); - } - - @Test - public void testDenseTensorEncoding() { - var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - tester.assertEmbedded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]"); - tester.assertEmbedded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]"); - tester.assertEmbedded("hello, world!", "tensor(d[2])", "[908,1418]"); + public void testEnglishEmbedding() { + var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build()); + tester.assertEmbedded("hello, world!", "tensor(d[10])", 908, 1418, 9934, 501, 9960); + tester.assertEmbedded("Hello, world!", "tensor(d[10])", 9912, 0, 6595, 9934, 501, 9960); + tester.assertEmbedded("hello, world!", "tensor(d[2])", 908, 1418, 9934, 501, 9960); } @Test public void testNoCollapse() { - var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder() - .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) - .setCollapseUnknowns(false)); + var builder = new SentencePieceEmbedder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setCollapseUnknowns(false); + var tester = new EmbedderTester(builder.build()); tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); } @Test public void testHighestScore() { - var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder() - .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) - .setScoring(Scoring.highestScore)); + var builder = new SentencePieceEmbedder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setScoring(Scoring.highestScore); + var tester = new EmbedderTester(builder.build()); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); tester.assertSegmented("hel", "▁h", "el"); @@ -74,7 +70,7 @@ public class SentencePieceTest { SentencePieceEmbedder.Builder builder = new SentencePieceEmbedder.Builder(); builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - var tester = new SentencePieceTester(builder); + var tester = new EmbedderTester(builder.build()); tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java deleted file mode 100644 index 4dae53c60df..00000000000 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -// - -package com.yahoo.language.sentencepiece; - -import com.yahoo.language.Language; -import com.yahoo.language.process.Embedder; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.nio.file.Path; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; - -class SentencePieceTester { - - private final SentencePieceEmbedder embedder; - - public SentencePieceTester(Path model) { - this(new SentencePieceEmbedder.Builder().addDefaultModel(model)); - } - - public SentencePieceTester(SentencePieceEmbedder.Builder builder) { - this(builder.build()); - } - - public SentencePieceTester(SentencePieceEmbedder embedder) { - this.embedder = embedder; - } - - public void assertEmbedded(String input, Integer... expectedCodes) { - assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray()); - } - - public void assertEmbedded(String input, String tensorType, String tensor) { - TensorType type = TensorType.fromSpec(tensorType); - Tensor expected = Tensor.from(type, tensor); - assertEquals(expected, embedder.embed(input, new Embedder.Context("test"), type)); - } - - public void assertSegmented(String input, String... expectedSegments) { - assertSegmented(Language.UNKNOWN, input, expectedSegments); - } - - public void assertSegmented(Language language, String input, String... expectedSegments) { - assertArrayEquals(expectedSegments, embedder.segment(input, language).toArray()); - } - -} |