summary | refs | log | tree | commit | diff | stats
path: root/linguistics-components/src/test/java/com/yahoo
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@gmail.com> 2021-12-17 12:41:17 +0100
committer: Jon Bratseth <bratseth@gmail.com> 2021-12-17 12:41:17 +0100
commit: 601b117281b74a578126a0f3effead55bc79c680 (patch)
tree: 29619184a8459763cc024b23e74960e6c9ec7f81 /linguistics-components/src/test/java/com/yahoo
parent: 767cb63af0f530605180f5438767406e1db27520 (diff)
BERT -> WordPiece, make subword prefix configurable
Diffstat (limited to 'linguistics-components/src/test/java/com/yahoo')
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java 54
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java 9
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java 38
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java 50
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java 59
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java 38
6 files changed, 119 insertions, 129 deletions
diff --git a/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java b/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java
deleted file mode 100644
index 1bc25e0d217..00000000000
--- a/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.language.bert;
-
-import com.yahoo.config.FileReference;
-import com.yahoo.language.Language;
-import com.yahoo.language.process.Embedder;
-import com.yahoo.language.simple.SimpleLinguistics;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-import org.junit.Test;
-
-import java.io.File;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * Tests the BERT embedder
- *
- * @author bratseth
- */
-public class BertEmbedderTest {
-
- private static final String vocabulary = "src/test/models/bert/bert-base-uncased-vocab.txt";
-
- @Test
- public void testBertEmbedder() {
- var embedder = new BertEmbedder.Builder().addDefaultModel(new File(vocabulary).toPath()).build();
- var expectedTokenIds = List.of(2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622);
- assertEquals(expectedTokenIds, embedder.embed("what was the impact of the manhattan project",
- new Embedder.Context("destination")));
-
- var expectedTokens = List.of("what", "was", "the", "impact", "of", "the", "manhattan", "project");
- assertEquals(expectedTokens, embedder.segment("what was the impact of the manhattan project",
- Language.ENGLISH));
-
- var expectedDenseTensor = Tensor.from("tensor(x[8]):" + expectedTokenIds);
- assertEquals(expectedDenseTensor, embedder.embed("what was the impact of the manhattan project",
- new Embedder.Context("destination"),
- expectedDenseTensor.type()));
- }
-
- @Test
- public void testBertEmbedderConfiguration() {
- var config = new BertConfig.Builder().model(new BertConfig.Model.Builder().language("unknown")
- .path(new FileReference(vocabulary)))
- .build();
- var embedder = new BertEmbedder(config);
- var expectedTokenIds = List.of(2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622);
- assertEquals(expectedTokenIds, embedder.embed("what was the impact of the manhattan project",
- new Embedder.Context("destination")));
- }
-
-}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
index 1ed2271f774..19cb2267655 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
@@ -4,6 +4,7 @@ package com.yahoo.language.sentencepiece;
import com.yahoo.config.FileReference;
import com.yahoo.language.Language;
+import com.yahoo.language.tools.EmbedderTester;
import org.junit.Test;
/**
@@ -15,7 +16,7 @@ public class SentencePieceConfigurationTest {
public void testEnglishTokenization() {
var b = new SentencePieceConfig.Builder();
addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
- var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
+ var tester = new EmbedderTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence");
tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo");
}
@@ -25,7 +26,7 @@ public class SentencePieceConfigurationTest {
var b = new SentencePieceConfig.Builder();
addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
b.collapseUnknowns(false);
- var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
+ var tester = new EmbedderTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
}
@@ -34,7 +35,7 @@ public class SentencePieceConfigurationTest {
var b = new SentencePieceConfig.Builder();
addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
b.scoring(SentencePieceConfig.Scoring.highestScore);
- var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
+ var tester = new EmbedderTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented("hello", "▁h", "el", "lo");
}
@@ -43,7 +44,7 @@ public class SentencePieceConfigurationTest {
var b = new SentencePieceConfig.Builder();
addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b);
addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
- var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
+ var tester = new EmbedderTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index 8b3e2988c43..2fbafb23485 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -3,6 +3,7 @@
package com.yahoo.language.sentencepiece;
import com.yahoo.language.Language;
+import com.yahoo.language.tools.EmbedderTester;
import org.junit.Test;
import java.io.File;
@@ -13,8 +14,8 @@ import java.io.File;
public class SentencePieceTest {
@Test
- public void testEnglishTokenization() {
- var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ public void testEnglishSegmenting() {
+ var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build());
tester.assertSegmented("h", "▁h");
tester.assertSegmented("he", "▁he");
tester.assertSegmented("hel", "▁hel");
@@ -36,33 +37,28 @@ public class SentencePieceTest {
}
@Test
- public void testIntegerListEncoding() {
- var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
- tester.assertEmbedded("hello, world!", 908, 1418, 9934, 501, 9960);
- tester.assertEmbedded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960);
- }
-
- @Test
- public void testDenseTensorEncoding() {
- var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
- tester.assertEmbedded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]");
- tester.assertEmbedded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]");
- tester.assertEmbedded("hello, world!", "tensor(d[2])", "[908,1418]");
+ public void testEnglishEmbedding() {
+ var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build());
+ tester.assertEmbedded("hello, world!", "tensor(d[10])", 908, 1418, 9934, 501, 9960);
+ tester.assertEmbedded("Hello, world!", "tensor(d[10])", 9912, 0, 6595, 9934, 501, 9960);
+ tester.assertEmbedded("hello, world!", "tensor(d[2])", 908, 1418, 9934, 501, 9960);
}
@Test
public void testNoCollapse() {
- var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder()
- .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
- .setCollapseUnknowns(false));
+ var builder = new SentencePieceEmbedder.Builder()
+ .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
+ .setCollapseUnknowns(false);
+ var tester = new EmbedderTester(builder.build());
tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
}
@Test
public void testHighestScore() {
- var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder()
- .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
- .setScoring(Scoring.highestScore));
+ var builder = new SentencePieceEmbedder.Builder()
+ .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
+ .setScoring(Scoring.highestScore);
+ var tester = new EmbedderTester(builder.build());
tester.assertSegmented("h", "▁h");
tester.assertSegmented("he", "▁he");
tester.assertSegmented("hel", "▁h", "el");
@@ -74,7 +70,7 @@ public class SentencePieceTest {
SentencePieceEmbedder.Builder builder = new SentencePieceEmbedder.Builder();
builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath());
builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
- var tester = new SentencePieceTester(builder);
+ var tester = new EmbedderTester(builder.build());
tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
deleted file mode 100644
index 4dae53c60df..00000000000
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-//
-
-package com.yahoo.language.sentencepiece;
-
-import com.yahoo.language.Language;
-import com.yahoo.language.process.Embedder;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.nio.file.Path;
-
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-
-class SentencePieceTester {
-
- private final SentencePieceEmbedder embedder;
-
- public SentencePieceTester(Path model) {
- this(new SentencePieceEmbedder.Builder().addDefaultModel(model));
- }
-
- public SentencePieceTester(SentencePieceEmbedder.Builder builder) {
- this(builder.build());
- }
-
- public SentencePieceTester(SentencePieceEmbedder embedder) {
- this.embedder = embedder;
- }
-
- public void assertEmbedded(String input, Integer... expectedCodes) {
- assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray());
- }
-
- public void assertEmbedded(String input, String tensorType, String tensor) {
- TensorType type = TensorType.fromSpec(tensorType);
- Tensor expected = Tensor.from(type, tensor);
- assertEquals(expected, embedder.embed(input, new Embedder.Context("test"), type));
- }
-
- public void assertSegmented(String input, String... expectedSegments) {
- assertSegmented(Language.UNKNOWN, input, expectedSegments);
- }
-
- public void assertSegmented(Language language, String input, String... expectedSegments) {
- assertArrayEquals(expectedSegments, embedder.segment(input, language).toArray());
- }
-
-}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
new file mode 100644
index 00000000000..9599e60e556
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
@@ -0,0 +1,59 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.tools;
+
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.language.process.Segmenter;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tester of embedders.
+ *
+ * @author bratseth
+ */
+public class EmbedderTester {
+
+ private final Embedder embedder;
+
+ public EmbedderTester(Embedder embedder) {
+ this.embedder = embedder;
+ }
+
+ /**
+ * Tests both embedding to a list of ids and encoding the same ids to a vector of the given type.
+ *
+ * @param expectedCodes all the expected codes of the given input, not including any trailing 0-paddings
+ * required for the tensor only
+ */
+ public void assertEmbedded(String input, String tensorType, Integer... expectedCodes) {
+ TensorType type = TensorType.fromSpec(tensorType);
+ assertEquals(1, type.dimensions().size());
+ assertTrue(type.dimensions().get(0).isIndexed());
+
+ int tensorSize = type.dimensions().get(0).size().get().intValue();
+
+ assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray());
+
+ var builder = Tensor.Builder.of(type);
+ for (int i = 0; i < tensorSize; i++)
+ builder.cell(i < expectedCodes.length ? expectedCodes[i] : 0, i);
+ assertEquals(builder.build(), embedder.embed(input, new Embedder.Context("destination"), type));
+ }
+
+ public void assertSegmented(String input, String... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+
+ public void assertSegmented(Language language, String input, String... expectedSegments) {
+ assertArrayEquals(expectedSegments, ((Segmenter)embedder).segment(input, language).toArray());
+ }
+
+}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java
new file mode 100644
index 00000000000..4cbfe541327
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java
@@ -0,0 +1,38 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.wordpiece;
+
+import com.yahoo.config.FileReference;
+import com.yahoo.language.tools.EmbedderTester;
+import org.junit.Test;
+
+/**
+ * Tests the WordPiece embedder
+ *
+ * @author bratseth
+ */
+public class WordPieceEmbedderTest {
+
+ private static final String vocabulary = "src/test/models/wordpiece/bert-base-uncased-vocab.txt";
+
+ @Test
+ public void testWordPieceEmbedder() {
+ var tester = new EmbedderTester(new WordPieceEmbedder.Builder(vocabulary).build());
+ tester.assertEmbedded("what was the impact of the manhattan project",
+ "tensor(x[8])",
+ 2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622);
+ tester.assertSegmented("what was the impact of the manhattan project",
+ "what", "was", "the", "impact", "of", "the", "manhattan", "project");
+ }
+
+ @Test
+ public void testWordPieceEmbedderConfiguration() {
+ var config = new WordPieceConfig.Builder().model(new WordPieceConfig.Model.Builder()
+ .language("unknown")
+ .path(new FileReference(vocabulary)))
+ .build();
+ var tester = new EmbedderTester(new WordPieceEmbedder(config));
+ tester.assertSegmented("what was the impact of the manhattan project",
+ "what", "was", "the", "impact", "of", "the", "manhattan", "project");
+ }
+
+}