diff options
Diffstat (limited to 'linguistics-components/src')
12 files changed, 184 insertions, 162 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index ff7f4ae42bc..31964eac514 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -10,6 +10,7 @@ import com.yahoo.language.process.Segmenter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.io.File; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; @@ -136,13 +137,16 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { return b.toString(); } - public static class Builder { + public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); private boolean collapseUnknowns = true; private Scoring scoring = Scoring.fewestSegments; - public Builder() { + public Builder() {} + + public Builder(String defaultModelFile) { + addDefaultModel(new File(defaultModelFile).toPath()); } private Builder(SentencePieceConfig config) { diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java index 54f37d597ce..ce996b85313 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java @@ -1,5 +1,5 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.bert; +package com.yahoo.language.wordpiece; import com.yahoo.collections.Tuple2; import com.yahoo.language.Language; @@ -23,7 +23,7 @@ import java.util.TreeMap; import java.util.stream.Collectors; /** - * A BERT embedder "model" - just a vocabulary of strings with a fixed id (index). + * A WordPiece embedder "model" - just a vocabulary of strings with a fixed id (index). * * Adapted from * https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java @@ -34,12 +34,14 @@ import java.util.stream.Collectors; */ class Model { - final Path source; - final Language language; + private final String subwordPrefix; + private final Path source; + private final Language language; private final NavigableMap<String, Integer> vocabulary; private final Map<Integer, String> tokenId2Token; - Model(Language language, Path path) { + Model(String subwordPrefix, Language language, Path path) { + this.subwordPrefix = subwordPrefix; this.source = path; this.language = language; @@ -56,23 +58,25 @@ class Model { } } catch (IOException e) { - throw new IllegalArgumentException("Could not read a BERT model from " + path, e); + throw new IllegalArgumentException("Could not read a WordPiece model from " + path, e); } } - public List<Integer> embed(String text, Tokenizer tokenizer) { + Language language() { return language; } + + List<Integer> embed(String text, Tokenizer tokenizer) { List<Integer> ids = new ArrayList<>(); text = text.toLowerCase(); for (Token t : tokenizer.tokenize(text, language, StemMode.NONE, true)) { String originalToken = t.getTokenString(); String candidate = originalToken; int count = 0; - while (candidate.length() > 0 && !"##".equals(candidate)) { + while (candidate.length() > 0 && !candidate.equals(subwordPrefix)) { Tuple2<String, Integer> entry = findLongestSubstring(candidate); if (entry == null) break; ids.add(entry.second); - candidate = "##" + candidate.substring(entry.first.length()); + candidate = subwordPrefix + candidate.substring(entry.first.length()); if (count++ > originalToken.length()) break; } } @@ -80,7 +84,7 @@ class Model { return ids; } - public List<String> segment(String text, Tokenizer tokenizer) { + List<String> segment(String text, Tokenizer tokenizer) { return embed(text, tokenizer).stream().map(tokenId -> tokenId2Token.get(tokenId)).collect(Collectors.toList()); } @@ -102,4 +106,9 @@ class Model { return new Tuple2<>(longestSubstring, id); } + @Override + public String toString() { + return "WordPiece model for " + language + ": '" + source + "'"; + } + } diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java index c2b19391e74..08de05f351a 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java @@ -1,5 +1,5 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.bert; +package com.yahoo.language.wordpiece; import com.google.inject.Inject; import com.yahoo.language.tools.Embed; @@ -10,7 +10,9 @@ import com.yahoo.language.process.Tokenizer; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.language.wordpiece.WordPieceConfig; +import java.io.File; import java.nio.file.Path; import java.util.EnumMap; import java.util.List; @@ -18,35 +20,37 @@ import java.util.Map; import java.util.stream.Collectors; /** - * An embedder to use with BERT models: Text is tokenized into tokens from a configured vocabulary, + * An implementation of the WordPiece embedder, usually used with BERT models, + * see https://arxiv.org/pdf/1609.08144v2.pdf + * Text is tokenized into tokens from a configured vocabulary, * and optionally returned as a 1-d dense tensor of token ids. * * @author bratseth */ -public class BertEmbedder implements Embedder, Segmenter { +public class WordPieceEmbedder implements Embedder, Segmenter { private final Map<Language, Model> models; private final Tokenizer tokenizer; @Inject - public BertEmbedder(BertConfig config) { + public WordPieceEmbedder(WordPieceConfig config) { this(new Builder(config)); } - private BertEmbedder(Builder builder) { + private WordPieceEmbedder(Builder builder) { super(); this.tokenizer = new SimpleLinguistics().getTokenizer(); // always just split on spaces etc. and lowercase models = builder.getModels().entrySet() .stream() - .map(e -> new Model(e.getKey(), e.getValue())) - .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); + .map(e -> new Model(builder.getSubwordPrefix(), e.getKey(), e.getValue())) + .collect(Collectors.toUnmodifiableMap(m -> m.language(), m -> m)); if (models.isEmpty()) - throw new IllegalArgumentException("BertEmbedder requires at least one model configured"); + throw new IllegalArgumentException("WordPieceEmbedder requires at least one model configured"); } /** - * Segments the given text into token segments from the BERT vocabulary. + * Segments the given text into token segments from the WordPiece vocabulary. * * @param text the text to segment. The text should be of a language using space-separated words. * @return the list of zero or more token ids resulting from segmenting the input text @@ -57,7 +61,7 @@ public class BertEmbedder implements Embedder, Segmenter { } /** - * Segments the given text into token segments from the BERT vocabulary and returns the token ids. + * Segments the given text into token segments from the WordPiece vocabulary and returns the token ids. * * @param text the text to segment. The text should be of a language using space-separated words. * @param context the context which specifies the language used to select a model @@ -90,21 +94,33 @@ public class BertEmbedder implements Embedder, Segmenter { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); if (models.containsKey(language)) return models.get(language); - throw new IllegalArgumentException("No BERT model for language " + language + " is configured"); + throw new IllegalArgumentException("No WordPiece model for language " + language + " is configured"); } - public static class Builder { + public static final class Builder { + private String subwordPrefix = "##"; private final Map<Language, Path> models = new EnumMap<>(Language.class); - public Builder() { + public Builder() {} + + public Builder(String defaultModelFile) { + addDefaultModel(new File(defaultModelFile).toPath()); } - private Builder(BertConfig config) { - for (BertConfig.Model model : config.model()) + private Builder(WordPieceConfig config) { + this.subwordPrefix = config.subwordPrefix(); + for (WordPieceConfig.Model model : config.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); } + public Builder setSubwordPrefix(String prefix) { + this.subwordPrefix = subwordPrefix; + return this; + } + + public String getSubwordPrefix() { return subwordPrefix; } + public void addModel(Language language, Path model) { models.put(language, model); } @@ -113,16 +129,16 @@ public class BertEmbedder implements Embedder, Segmenter { * Adds the model that will be used if the language is unknown, OR only one model is specified. * The same as addModel(Language.UNKNOWN, model). */ - public BertEmbedder.Builder addDefaultModel(Path model) { + public WordPieceEmbedder.Builder addDefaultModel(Path model) { addModel(Language.UNKNOWN, model); return this; } public Map<Language, Path> getModels() { return models; } - public BertEmbedder build() { + public WordPieceEmbedder build() { if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); - return new BertEmbedder(this); + return new WordPieceEmbedder(this); } } diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java index e3f612f4114..0bbb6f001f5 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage @PublicApi -package com.yahoo.language.bert; +package com.yahoo.language.wordpiece; import com.yahoo.api.annotations.PublicApi; import com.yahoo.osgi.annotation.ExportPackage; diff --git a/linguistics-components/src/main/resources/configdefinitions/language.bert.bert.def b/linguistics-components/src/main/resources/configdefinitions/language.wordpiece.word-piece.def index 86d338758d0..08592250eb5 100644 --- a/linguistics-components/src/main/resources/configdefinitions/language.bert.bert.def +++ b/linguistics-components/src/main/resources/configdefinitions/language.wordpiece.word-piece.def @@ -1,8 +1,11 @@ # Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -# Configures com.yahoo.language.bert.BertEmbedder +# Configures com.yahoo.language.wordpiece.WordPieceEmbedder -namespace=language.bert +namespace=language.wordpiece + +# The prefix to prepend to subword tokens +subwordPrefix string default="##" # The language a model is for, one of the language tags in com.yahoo.language.Language. # Use "unknown" for a model to be used for any language (i.e by default). diff --git a/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java b/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java deleted file mode 100644 index 1bc25e0d217..00000000000 --- a/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.bert; - -import com.yahoo.config.FileReference; -import com.yahoo.language.Language; -import com.yahoo.language.process.Embedder; -import com.yahoo.language.simple.SimpleLinguistics; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import org.junit.Test; - -import java.io.File; -import java.util.List; - -import static org.junit.Assert.assertEquals; - -/** - * Tests the BERT embedder - * - * @author bratseth - */ -public class BertEmbedderTest { - - private static final String vocabulary = "src/test/models/bert/bert-base-uncased-vocab.txt"; - - @Test - public void testBertEmbedder() { - var embedder = new BertEmbedder.Builder().addDefaultModel(new File(vocabulary).toPath()).build(); - var expectedTokenIds = List.of(2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622); - assertEquals(expectedTokenIds, embedder.embed("what was the impact of the manhattan project", - new Embedder.Context("destination"))); - - var expectedTokens = List.of("what", "was", "the", "impact", "of", "the", "manhattan", "project"); - assertEquals(expectedTokens, embedder.segment("what was the impact of the manhattan project", - Language.ENGLISH)); - - var expectedDenseTensor = Tensor.from("tensor(x[8]):" + expectedTokenIds); - assertEquals(expectedDenseTensor, embedder.embed("what was the impact of the manhattan project", - new Embedder.Context("destination"), - expectedDenseTensor.type())); - } - - @Test - public void testBertEmbedderConfiguration() { - var config = new BertConfig.Builder().model(new BertConfig.Model.Builder().language("unknown") - .path(new FileReference(vocabulary))) - .build(); - var embedder = new BertEmbedder(config); - var expectedTokenIds = List.of(2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622); - assertEquals(expectedTokenIds, embedder.embed("what was the impact of the manhattan project", - new Embedder.Context("destination"))); - } - -} diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java index 1ed2271f774..19cb2267655 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java @@ -4,6 +4,7 @@ package com.yahoo.language.sentencepiece; import com.yahoo.config.FileReference; import com.yahoo.language.Language; +import com.yahoo.language.tools.EmbedderTester; import org.junit.Test; /** @@ -15,7 +16,7 @@ public class SentencePieceConfigurationTest { public void testEnglishTokenization() { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); } @@ -25,7 +26,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); b.collapseUnknowns(false); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); } @@ -34,7 +35,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); b.scoring(SentencePieceConfig.Scoring.highestScore); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("hello", "▁h", "el", "lo"); } @@ -43,7 +44,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b); addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index 8b3e2988c43..2fbafb23485 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -3,6 +3,7 @@ package com.yahoo.language.sentencepiece; import com.yahoo.language.Language; +import com.yahoo.language.tools.EmbedderTester; import org.junit.Test; import java.io.File; @@ -13,8 +14,8 @@ import java.io.File; public class SentencePieceTest { @Test - public void testEnglishTokenization() { - var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + public void testEnglishSegmenting() { + var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build()); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); tester.assertSegmented("hel", "▁hel"); @@ -36,33 +37,28 @@ public class SentencePieceTest { } @Test - public void testIntegerListEncoding() { - var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - tester.assertEmbedded("hello, world!", 908, 1418, 9934, 501, 9960); - tester.assertEmbedded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960); - } - - @Test - public void testDenseTensorEncoding() { - var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - tester.assertEmbedded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]"); - tester.assertEmbedded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]"); - tester.assertEmbedded("hello, world!", "tensor(d[2])", "[908,1418]"); + public void testEnglishEmbedding() { + var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build()); + tester.assertEmbedded("hello, world!", "tensor(d[10])", 908, 1418, 9934, 501, 9960); + tester.assertEmbedded("Hello, world!", "tensor(d[10])", 9912, 0, 6595, 9934, 501, 9960); + tester.assertEmbedded("hello, world!", "tensor(d[2])", 908, 1418, 9934, 501, 9960); } @Test public void testNoCollapse() { - var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder() - .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) - .setCollapseUnknowns(false)); + var builder = new SentencePieceEmbedder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setCollapseUnknowns(false); + var tester = new EmbedderTester(builder.build()); tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); } @Test public void testHighestScore() { - var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder() - .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) - .setScoring(Scoring.highestScore)); + var builder = new SentencePieceEmbedder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setScoring(Scoring.highestScore); + var tester = new EmbedderTester(builder.build()); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); tester.assertSegmented("hel", "▁h", "el"); @@ -74,7 +70,7 @@ public class SentencePieceTest { SentencePieceEmbedder.Builder builder = new SentencePieceEmbedder.Builder(); builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - var tester = new SentencePieceTester(builder); + var tester = new EmbedderTester(builder.build()); tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java deleted file mode 100644 index 4dae53c60df..00000000000 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -// - -package com.yahoo.language.sentencepiece; - -import com.yahoo.language.Language; -import com.yahoo.language.process.Embedder; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.nio.file.Path; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; - -class SentencePieceTester { - - private final SentencePieceEmbedder embedder; - - public SentencePieceTester(Path model) { - this(new SentencePieceEmbedder.Builder().addDefaultModel(model)); - } - - public SentencePieceTester(SentencePieceEmbedder.Builder builder) { - this(builder.build()); - } - - public SentencePieceTester(SentencePieceEmbedder embedder) { - this.embedder = embedder; - } - - public void assertEmbedded(String input, Integer... expectedCodes) { - assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray()); - } - - public void assertEmbedded(String input, String tensorType, String tensor) { - TensorType type = TensorType.fromSpec(tensorType); - Tensor expected = Tensor.from(type, tensor); - assertEquals(expected, embedder.embed(input, new Embedder.Context("test"), type)); - } - - public void assertSegmented(String input, String... expectedSegments) { - assertSegmented(Language.UNKNOWN, input, expectedSegments); - } - - public void assertSegmented(Language language, String input, String... expectedSegments) { - assertArrayEquals(expectedSegments, embedder.segment(input, language).toArray()); - } - -} diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java new file mode 100644 index 00000000000..9599e60e556 --- /dev/null +++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java @@ -0,0 +1,59 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.tools; + +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Segmenter; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.Arrays; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tester of embedders. + * + * @author bratseth + */ +public class EmbedderTester { + + private final Embedder embedder; + + public EmbedderTester(Embedder embedder) { + this.embedder = embedder; + } + + /** + * Tests both embedding to a list of id's and encoding the same ids to a vector of the given type. + * + * @param expectedCodes all the expected codes of the given input, not including any trailing 0-paddings + * required for the tensor only + */ + public void assertEmbedded(String input, String tensorType, Integer... expectedCodes) { + TensorType type = TensorType.fromSpec(tensorType); + assertEquals(1, type.dimensions().size()); + assertTrue(type.dimensions().get(0).isIndexed()); + + int tensorSize = type.dimensions().get(0).size().get().intValue(); + + assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray()); + + var builder = Tensor.Builder.of(type); + for (int i = 0; i < tensorSize; i++) + builder.cell(i < expectedCodes.length ? expectedCodes[i] : 0, i); + assertEquals(builder.build(), embedder.embed(input, new Embedder.Context("destination"), type)); + } + + public void assertSegmented(String input, String... expectedSegments) { + assertSegmented(Language.UNKNOWN, input, expectedSegments); + } + + public void assertSegmented(Language language, String input, String... expectedSegments) { + assertArrayEquals(expectedSegments, ((Segmenter)embedder).segment(input, language).toArray()); + } + +} diff --git a/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java new file mode 100644 index 00000000000..4cbfe541327 --- /dev/null +++ b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java @@ -0,0 +1,38 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.wordpiece; + +import com.yahoo.config.FileReference; +import com.yahoo.language.tools.EmbedderTester; +import org.junit.Test; + +/** + * Tests the WordPiece embedder + * + * @author bratseth + */ +public class WordPieceEmbedderTest { + + private static final String vocabulary = "src/test/models/wordpiece/bert-base-uncased-vocab.txt"; + + @Test + public void testWordPieceEmbedder() { + var tester = new EmbedderTester(new WordPieceEmbedder.Builder(vocabulary).build()); + tester.assertEmbedded("what was the impact of the manhattan project", + "tensor(x[8])", + 2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622); + tester.assertSegmented("what was the impact of the manhattan project", + "what", "was", "the", "impact", "of", "the", "manhattan", "project"); + } + + @Test + public void testWordPieceEmbedderConfiguration() { + var config = new WordPieceConfig.Builder().model(new WordPieceConfig.Model.Builder() + .language("unknown") + .path(new FileReference(vocabulary))) + .build(); + var tester = new EmbedderTester(new WordPieceEmbedder(config)); + tester.assertSegmented("what was the impact of the manhattan project", + "what", "was", "the", "impact", "of", "the", "manhattan", "project"); + } + +} diff --git a/linguistics-components/src/test/models/bert/bert-base-uncased-vocab.txt b/linguistics-components/src/test/models/wordpiece/bert-base-uncased-vocab.txt index fb140275c15..fb140275c15 100644 --- a/linguistics-components/src/test/models/bert/bert-base-uncased-vocab.txt +++ b/linguistics-components/src/test/models/wordpiece/bert-base-uncased-vocab.txt |