diff options
Diffstat (limited to 'linguistics-components')
-rw-r--r-- | linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java | 5 | ||||
-rw-r--r-- | linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java | 12 |
2 files changed, 13 insertions, 4 deletions
diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java index 9599e60e556..638efba2480 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java +++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.Arrays; +import java.util.List; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -53,7 +54,9 @@ public class EmbedderTester { } public void assertSegmented(Language language, String input, String... expectedSegments) { - assertArrayEquals(expectedSegments, ((Segmenter)embedder).segment(input, language).toArray()); + List<String> segments = ((Segmenter)embedder).segment(input, language); + assertArrayEquals("Actual segments: " + segments, + expectedSegments, ((Segmenter)embedder).segment(input, language).toArray()); } } diff --git a/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java index 4cbfe541327..13e0cbce10d 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java @@ -15,13 +15,19 @@ public class WordPieceEmbedderTest { private static final String vocabulary = "src/test/models/wordpiece/bert-base-uncased-vocab.txt"; @Test - public void testWordPieceEmbedder() { + public void testWordPieceSegmentation() { + var tester = new EmbedderTester(new WordPieceEmbedder.Builder(vocabulary).build()); + tester.assertSegmented("what was the impact of the manhattan project", + "what", "was", "the", "impact", "of", "the", "manhattan", "project"); + tester.assertSegmented("overcommunication", "over", "##com", "##mun", "##ication"); + } + + @Test + public void testWordPieceEmbedding() { var tester = new EmbedderTester(new WordPieceEmbedder.Builder(vocabulary).build()); tester.assertEmbedded("what was the impact of the manhattan project", "tensor(x[8])", 2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622); - tester.assertSegmented("what was the impact of the manhattan project", - "what", "was", "the", "impact", "of", "the", "manhattan", "project"); } @Test |