diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-12-17 12:41:17 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-12-17 12:41:17 +0100 |
commit | 601b117281b74a578126a0f3effead55bc79c680 (patch) | |
tree | 29619184a8459763cc024b23e74960e6c9ec7f81 /linguistics-components/src/main/java | |
parent | 767cb63af0f530605180f5438767406e1db27520 (diff) |
BERT -> WordPiece, make subword prefix configurable
Diffstat (limited to 'linguistics-components/src/main/java')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java | 8 | ||||
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java (renamed from linguistics-components/src/main/java/com/yahoo/language/bert/Model.java) | 29 | ||||
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java (renamed from linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java) | 52 | ||||
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java (renamed from linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java) | 2 |
4 files changed, 60 insertions, 31 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index ff7f4ae42bc..31964eac514 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -10,6 +10,7 @@ import com.yahoo.language.process.Segmenter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.io.File; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; @@ -136,13 +137,16 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { return b.toString(); } - public static class Builder { + public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); private boolean collapseUnknowns = true; private Scoring scoring = Scoring.fewestSegments; - public Builder() { + public Builder() {} + + public Builder(String defaultModelFile) { + addDefaultModel(new File(defaultModelFile).toPath()); } private Builder(SentencePieceConfig config) { diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java index 54f37d597ce..ce996b85313 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java @@ -1,5 +1,5 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.bert; +package com.yahoo.language.wordpiece; import com.yahoo.collections.Tuple2; import com.yahoo.language.Language; @@ -23,7 +23,7 @@ import java.util.TreeMap; import java.util.stream.Collectors; /** - * A BERT embedder "model" - just a vocabulary of strings with a fixed id (index). + * A WordPiece embedder "model" - just a vocabulary of strings with a fixed id (index). * * Adapted from * https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java @@ -34,12 +34,14 @@ import java.util.stream.Collectors; */ class Model { - final Path source; - final Language language; + private final String subwordPrefix; + private final Path source; + private final Language language; private final NavigableMap<String, Integer> vocabulary; private final Map<Integer, String> tokenId2Token; - Model(Language language, Path path) { + Model(String subwordPrefix, Language language, Path path) { + this.subwordPrefix = subwordPrefix; this.source = path; this.language = language; @@ -56,23 +58,25 @@ class Model { } } catch (IOException e) { - throw new IllegalArgumentException("Could not read a BERT model from " + path, e); + throw new IllegalArgumentException("Could not read a WordPiece model from " + path, e); } } - public List<Integer> embed(String text, Tokenizer tokenizer) { + Language language() { return language; } + + List<Integer> embed(String text, Tokenizer tokenizer) { List<Integer> ids = new ArrayList<>(); text = text.toLowerCase(); for (Token t : tokenizer.tokenize(text, language, StemMode.NONE, true)) { String originalToken = t.getTokenString(); String candidate = originalToken; int count = 0; - while (candidate.length() > 0 && !"##".equals(candidate)) { + while (candidate.length() > 0 && !candidate.equals(subwordPrefix)) { Tuple2<String, Integer> entry = findLongestSubstring(candidate); if (entry == null) break; ids.add(entry.second); - candidate = "##" + candidate.substring(entry.first.length()); + candidate = subwordPrefix + candidate.substring(entry.first.length()); if (count++ > originalToken.length()) break; } } @@ -80,7 +84,7 @@ class Model { return ids; } - public List<String> segment(String text, Tokenizer tokenizer) { + List<String> segment(String text, Tokenizer tokenizer) { return embed(text, tokenizer).stream().map(tokenId -> tokenId2Token.get(tokenId)).collect(Collectors.toList()); } @@ -102,4 +106,9 @@ class Model { return new Tuple2<>(longestSubstring, id); } + @Override + public String toString() { + return "WordPiece model for " + language + ": '" + source + "'"; + } + } diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java index c2b19391e74..08de05f351a 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java @@ -1,5 +1,5 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.bert; +package com.yahoo.language.wordpiece; import com.google.inject.Inject; import com.yahoo.language.tools.Embed; @@ -10,7 +10,9 @@ import com.yahoo.language.process.Tokenizer; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.language.wordpiece.WordPieceConfig; +import java.io.File; import java.nio.file.Path; import java.util.EnumMap; import java.util.List; @@ -18,35 +20,37 @@ import java.util.Map; import java.util.stream.Collectors; /** - * An embedder to use with BERT models: Text is tokenized into tokens from a configured vocabulary, + * An implementation of the WordPiece embedder, usually used with BERT models, + * see https://arxiv.org/pdf/1609.08144v2.pdf + * Text is tokenized into tokens from a configured vocabulary, * and optionally returned as a 1-d dense tensor of token ids. * * @author bratseth */ -public class BertEmbedder implements Embedder, Segmenter { +public class WordPieceEmbedder implements Embedder, Segmenter { private final Map<Language, Model> models; private final Tokenizer tokenizer; @Inject - public BertEmbedder(BertConfig config) { + public WordPieceEmbedder(WordPieceConfig config) { this(new Builder(config)); } - private BertEmbedder(Builder builder) { + private WordPieceEmbedder(Builder builder) { super(); this.tokenizer = new SimpleLinguistics().getTokenizer(); // always just split on spaces etc. and lowercase models = builder.getModels().entrySet() .stream() - .map(e -> new Model(e.getKey(), e.getValue())) - .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); + .map(e -> new Model(builder.getSubwordPrefix(), e.getKey(), e.getValue())) + .collect(Collectors.toUnmodifiableMap(m -> m.language(), m -> m)); if (models.isEmpty()) - throw new IllegalArgumentException("BertEmbedder requires at least one model configured"); + throw new IllegalArgumentException("WordPieceEmbedder requires at least one model configured"); } /** - * Segments the given text into token segments from the BERT vocabulary. + * Segments the given text into token segments from the WordPiece vocabulary. * * @param text the text to segment. The text should be of a language using space-separated words. * @return the list of zero or more token ids resulting from segmenting the input text @@ -57,7 +61,7 @@ public class BertEmbedder implements Embedder, Segmenter { } /** - * Segments the given text into token segments from the BERT vocabulary and returns the token ids. + * Segments the given text into token segments from the WordPiece vocabulary and returns the token ids. * * @param text the text to segment. The text should be of a language using space-separated words. * @param context the context which specifies the language used to select a model @@ -90,21 +94,33 @@ public class BertEmbedder implements Embedder, Segmenter { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); if (models.containsKey(language)) return models.get(language); - throw new IllegalArgumentException("No BERT model for language " + language + " is configured"); + throw new IllegalArgumentException("No WordPiece model for language " + language + " is configured"); } - public static class Builder { + public static final class Builder { + private String subwordPrefix = "##"; private final Map<Language, Path> models = new EnumMap<>(Language.class); - public Builder() { + public Builder() {} + + public Builder(String defaultModelFile) { + addDefaultModel(new File(defaultModelFile).toPath()); } - private Builder(BertConfig config) { - for (BertConfig.Model model : config.model()) + private Builder(WordPieceConfig config) { + this.subwordPrefix = config.subwordPrefix(); + for (WordPieceConfig.Model model : config.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); } + public Builder setSubwordPrefix(String prefix) { + this.subwordPrefix = subwordPrefix; + return this; + } + + public String getSubwordPrefix() { return subwordPrefix; } + public void addModel(Language language, Path model) { models.put(language, model); } @@ -113,16 +129,16 @@ public class BertEmbedder implements Embedder, Segmenter { * Adds the model that will be used if the language is unknown, OR only one model is specified. * The same as addModel(Language.UNKNOWN, model). */ - public BertEmbedder.Builder addDefaultModel(Path model) { + public WordPieceEmbedder.Builder addDefaultModel(Path model) { addModel(Language.UNKNOWN, model); return this; } public Map<Language, Path> getModels() { return models; } - public BertEmbedder build() { + public WordPieceEmbedder build() { if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); - return new BertEmbedder(this); + return new WordPieceEmbedder(this); } } diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java index e3f612f4114..0bbb6f001f5 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage @PublicApi -package com.yahoo.language.bert; +package com.yahoo.language.wordpiece; import com.yahoo.api.annotations.PublicApi; import com.yahoo.osgi.annotation.ExportPackage; |