diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-12-16 18:35:11 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-12-16 18:35:11 +0100 |
commit | 767cb63af0f530605180f5438767406e1db27520 (patch) | |
tree | c0ea9e8ec4ded2dea6064a45334e6f8a1408f7b8 /linguistics-components/src/main/java/com | |
parent | 1eefb9854bcd7ca264889239b32e7fc8c8830672 (diff) |
Add a BERT embedder
Diffstat (limited to 'linguistics-components/src/main/java/com')
5 files changed, 294 insertions, 37 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java new file mode 100644 index 00000000000..c2b19391e74 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java @@ -0,0 +1,131 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.bert; + +import com.google.inject.Inject; +import com.yahoo.language.tools.Embed; +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Segmenter; +import com.yahoo.language.process.Tokenizer; +import com.yahoo.language.simple.SimpleLinguistics; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.nio.file.Path; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An embedder to use with BERT models: Text is tokenized into tokens from a configured vocabulary, + * and optionally returned as a 1-d dense tensor of token ids. + * + * @author bratseth + */ +public class BertEmbedder implements Embedder, Segmenter { + + private final Map<Language, Model> models; + + private final Tokenizer tokenizer; + + @Inject + public BertEmbedder(BertConfig config) { + this(new Builder(config)); + } + + private BertEmbedder(Builder builder) { + super(); + this.tokenizer = new SimpleLinguistics().getTokenizer(); // always just split on spaces etc. and lowercase + models = builder.getModels().entrySet() + .stream() + .map(e -> new Model(e.getKey(), e.getValue())) + .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); + if (models.isEmpty()) + throw new IllegalArgumentException("BertEmbedder requires at least one model configured"); + } + + /** + * Segments the given text into token segments from the BERT vocabulary. + * + * @param text the text to segment. The text should be of a language using space-separated words. + * @return the list of zero or more token ids resulting from segmenting the input text + */ + @Override + public List<String> segment(String text, Language language) { + return resolveModelFrom(language).segment(text, tokenizer); + } + + /** + * Segments the given text into token segments from the BERT vocabulary and returns the token ids. + * + * @param text the text to segment. The text should be of a language using space-separated words. + * @param context the context which specifies the language used to select a model + * @return the list of zero or more token ids resulting from segmenting the input text + */ + @Override + public List<Integer> embed(String text, Context context) { + return resolveModelFrom(context.getLanguage()).embed(text, tokenizer); + } + + /** + * <p>Embeds text into a tensor.</p> + * + * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order + * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small + * it will be truncated.</p> + * + * <p>If the tensor is any other type IllegalArgumentException is thrown.</p> + * + * @param text the text to segment. The text should be of a language using space-separated words. + * @param context the context which specifies the language used to select a model + * @return the list of zero or more token ids resulting from segmenting the input text + */ + @Override + public Tensor embed(String text, Context context, TensorType type) { + return Embed.asTensor(text, this, context, type); + } + + private Model resolveModelFrom(Language language) { + // Disregard language if there is default model + if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); + if (models.containsKey(language)) return models.get(language); + throw new IllegalArgumentException("No BERT model for language " + language + " is configured"); + } + + public static class Builder { + + private final Map<Language, Path> models = new EnumMap<>(Language.class); + + public Builder() { + } + + private Builder(BertConfig config) { + for (BertConfig.Model model : config.model()) + addModel(Language.fromLanguageTag(model.language()), model.path()); + } + + public void addModel(Language language, Path model) { + models.put(language, model); + } + + /** + * Adds the model that will be used if the language is unknown, OR only one model is specified. + * The same as addModel(Language.UNKNOWN, model). + */ + public BertEmbedder.Builder addDefaultModel(Path model) { + addModel(Language.UNKNOWN, model); + return this; + } + + public Map<Language, Path> getModels() { return models; } + + public BertEmbedder build() { + if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); + return new BertEmbedder(this); + } + + } + +} + diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java b/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java new file mode 100644 index 00000000000..54f37d597ce --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java @@ -0,0 +1,105 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.bert; + +import com.yahoo.collections.Tuple2; +import com.yahoo.language.Language; +import com.yahoo.language.process.StemMode; +import com.yahoo.language.process.Token; +import com.yahoo.language.process.Tokenizer; + +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import java.util.TreeMap; +import java.util.stream.Collectors; + +/** + * A BERT embedder "model" - just a vocabulary of strings with a fixed id (index). + * + * Adapted from + * https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java + * licensed under the Apache License, Version 2.0 + * + * @author bergum + * @author bratseth + */ +class Model { + + final Path source; + final Language language; + private final NavigableMap<String, Integer> vocabulary; + private final Map<Integer, String> tokenId2Token; + + Model(Language language, Path path) { + this.source = path; + this.language = language; + + this.vocabulary = new TreeMap<>(Collections.reverseOrder()); + this.tokenId2Token = new HashMap<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(path.toFile()), + StandardCharsets.UTF_8))) { + String token; + int i = 0; + while ((token = reader.readLine()) != null) { + this.vocabulary.put(token, i); + this.tokenId2Token.put(i, token); + i++; + } + } + catch (IOException e) { + throw new IllegalArgumentException("Could not read a BERT model from " + path, e); + } + + } + + public List<Integer> embed(String text, Tokenizer tokenizer) { + List<Integer> ids = new ArrayList<>(); + text = text.toLowerCase(); + for (Token t : tokenizer.tokenize(text, language, StemMode.NONE, true)) { + String originalToken = t.getTokenString(); + String candidate = originalToken; + int count = 0; + while (candidate.length() > 0 && !"##".equals(candidate)) { + Tuple2<String, Integer> entry = findLongestSubstring(candidate); + if (entry == null) break; + ids.add(entry.second); + candidate = "##" + candidate.substring(entry.first.length()); + if (count++ > originalToken.length()) break; + } + } + + return ids; + } + + public List<String> segment(String text, Tokenizer tokenizer) { + return embed(text, tokenizer).stream().map(tokenId -> tokenId2Token.get(tokenId)).collect(Collectors.toList()); + } + + private Tuple2<String, Integer> findLongestSubstring(String candidate) { + NavigableMap<String, Integer> tailMap = this.vocabulary.tailMap(candidate, true); + if (tailMap.isEmpty()) + return null; + String longestSubstring = tailMap.firstKey(); + Integer id = tailMap.firstEntry().getValue(); + int subStringLength = Math.min(candidate.length(), longestSubstring.length()); + while (!candidate.startsWith(longestSubstring)) { + subStringLength--; + tailMap = tailMap.tailMap(candidate.substring(0, subStringLength), true); + if (tailMap.isEmpty()) + return null; + longestSubstring = tailMap.firstKey(); + id = tailMap.firstEntry().getValue(); + } + return new Tuple2<>(longestSubstring, id); + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java new file mode 100644 index 00000000000..e3f612f4114 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java @@ -0,0 +1,7 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package com.yahoo.language.bert; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index 3afc85300d4..ff7f4ae42bc 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -3,17 +3,17 @@ package com.yahoo.language.sentencepiece; import com.yahoo.api.annotations.Beta; import com.google.inject.Inject; +import com.yahoo.language.tools.Embed; import com.yahoo.language.Language; import com.yahoo.language.process.Embedder; import com.yahoo.language.process.Segmenter; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; +import java.util.EnumMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -38,7 +38,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { } public SentencePieceEmbedder(Builder builder) { - algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring()); + algorithm = new SentencePieceAlgorithm(builder.getCollapseUnknowns(), builder.getScoring()); models = builder.getModels().entrySet() .stream() @@ -94,9 +94,6 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small * it will be truncated.</p> * - * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token - * position as value.</p> - * * <p>If the tensor is any other type IllegalArgumentException is thrown.</p> * * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. @@ -105,40 +102,15 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { */ @Override public Tensor embed(String rawInput, Embedder.Context context, TensorType type) { - if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { - // Build to a list first since we can't reverse a tensor builder - List<Integer> values = embed(rawInput, context); - - long maxSize = values.size(); - if (type.dimensions().get(0).size().isPresent()) - maxSize = Math.min(maxSize, type.dimensions().get(0).size().get()); - - Tensor.Builder builder = Tensor.Builder.of(type); - for (int i = 0; i < maxSize; i++) - builder.cell(values.get(i), i); - return builder.build(); - } - else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) { - // Build to a list first since we can't reverse a tensor builder - List<String> values = segment(rawInput, context.getLanguage()); - - Tensor.Builder builder = Tensor.Builder.of(type); - for (int i = 0; i < values.size(); i++) - builder.cell(TensorAddress.ofLabels(values.get(i)), i); - return builder.build(); - } - else { - throw new IllegalArgumentException("Don't know how to embed with SentencePiece into " + type); - } + return Embed.asTensor(rawInput, this, context, type); } private <RESULTTYPE> void segment(String input, Language language, ResultBuilder<RESULTTYPE> resultBuilder) { - Model model = resolveFrom(language); - algorithm.segment(input, resultBuilder, model); + algorithm.segment(input, resultBuilder, resolveModelFrom(language)); } - private Model resolveFrom(Language language) { + private Model resolveModelFrom(Language language) { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); if (models.containsKey(language)) return models.get(language); @@ -166,7 +138,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { public static class Builder { - private final Map<Language, Path> models = new HashMap<>(); + private final Map<Language, Path> models = new EnumMap<>(Language.class); private boolean collapseUnknowns = true; private Scoring scoring = Scoring.fewestSegments; @@ -177,9 +149,8 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { collapseUnknowns = config.collapseUnknowns(); scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments : Scoring.highestScore; - for (SentencePieceConfig.Model model : config.model()) { + for (SentencePieceConfig.Model model : config.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); - } } public void addModel(Language language, Path model) { diff --git a/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java b/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java new file mode 100644 index 00000000000..401347cc452 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java @@ -0,0 +1,43 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.tools; + +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.List; + +/** + * Component internal helpers for embedding + * + * @author bratseth + */ +public class Embed { + + /** + * Convenience function which embeds the given string into the given tensor type (if possible), + * using the given embedder. + */ + public static Tensor asTensor(String text, + Embedder embedder, + Embedder.Context context, + TensorType type) { + if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { + // Build to a list first since we can't reverse a tensor builder + List<Integer> values = embedder.embed(text, context); + + long maxSize = values.size(); + if (type.dimensions().get(0).size().isPresent()) + maxSize = Math.min(maxSize, type.dimensions().get(0).size().get()); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int i = 0; i < maxSize; i++) + builder.cell(values.get(i), i); + return builder.build(); + } + else { + throw new IllegalArgumentException("Don't know how to embed into " + type); + } + } + +} |