aboutsummaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/main/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-12-16 18:35:11 +0100
committerJon Bratseth <bratseth@gmail.com>2021-12-16 18:35:11 +0100
commit767cb63af0f530605180f5438767406e1db27520 (patch)
treec0ea9e8ec4ded2dea6064a45334e6f8a1408f7b8 /linguistics-components/src/main/java/com
parent1eefb9854bcd7ca264889239b32e7fc8c8830672 (diff)
Add a BERT embedder
Diffstat (limited to 'linguistics-components/src/main/java/com')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java131
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/bert/Model.java105
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java7
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java45
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java43
5 files changed, 294 insertions, 37 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java
new file mode 100644
index 00000000000..c2b19391e74
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java
@@ -0,0 +1,131 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.bert;
+
+import com.google.inject.Inject;
+import com.yahoo.language.tools.Embed;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.language.process.Segmenter;
+import com.yahoo.language.process.Tokenizer;
+import com.yahoo.language.simple.SimpleLinguistics;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.file.Path;
+import java.util.EnumMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * An embedder to use with BERT models: Text is tokenized into tokens from a configured vocabulary,
+ * and optionally returned as a 1-d dense tensor of token ids.
+ *
+ * @author bratseth
+ */
+public class BertEmbedder implements Embedder, Segmenter {
+
+ private final Map<Language, Model> models;
+
+ private final Tokenizer tokenizer;
+
+ @Inject
+ public BertEmbedder(BertConfig config) {
+ this(new Builder(config));
+ }
+
+ private BertEmbedder(Builder builder) {
+ super();
+ this.tokenizer = new SimpleLinguistics().getTokenizer(); // always just split on spaces etc. and lowercase
+ models = builder.getModels().entrySet()
+ .stream()
+ .map(e -> new Model(e.getKey(), e.getValue()))
+ .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m));
+ if (models.isEmpty())
+ throw new IllegalArgumentException("BertEmbedder requires at least one model configured");
+ }
+
+ /**
+ * Segments the given text into token segments from the BERT vocabulary.
+ *
+ * @param text the text to segment. The text should be of a language using space-separated words.
+ * @return the list of zero or more token ids resulting from segmenting the input text
+ */
+ @Override
+ public List<String> segment(String text, Language language) {
+ return resolveModelFrom(language).segment(text, tokenizer);
+ }
+
+ /**
+ * Segments the given text into token segments from the BERT vocabulary and returns the token ids.
+ *
+ * @param text the text to segment. The text should be of a language using space-separated words.
+ * @param context the context which specifies the language used to select a model
+ * @return the list of zero or more token ids resulting from segmenting the input text
+ */
+ @Override
+ public List<Integer> embed(String text, Context context) {
+ return resolveModelFrom(context.getLanguage()).embed(text, tokenizer);
+ }
+
+ /**
+ * <p>Embeds text into a tensor.</p>
+ *
+ * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order
+ * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small
+ * it will be truncated.</p>
+ *
+ * <p>If the tensor is any other type IllegalArgumentException is thrown.</p>
+ *
+ * @param text the text to segment. The text should be of a language using space-separated words.
+ * @param context the context which specifies the language used to select a model
+ * @return the list of zero or more token ids resulting from segmenting the input text
+ */
+ @Override
+ public Tensor embed(String text, Context context, TensorType type) {
+ return Embed.asTensor(text, this, context, type);
+ }
+
+ private Model resolveModelFrom(Language language) {
+ // Disregard language if there is default model
+ if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
+ if (models.containsKey(language)) return models.get(language);
+ throw new IllegalArgumentException("No BERT model for language " + language + " is configured");
+ }
+
+ public static class Builder {
+
+ private final Map<Language, Path> models = new EnumMap<>(Language.class);
+
+ public Builder() {
+ }
+
+ private Builder(BertConfig config) {
+ for (BertConfig.Model model : config.model())
+ addModel(Language.fromLanguageTag(model.language()), model.path());
+ }
+
+ public void addModel(Language language, Path model) {
+ models.put(language, model);
+ }
+
+ /**
+ * Adds the model that will be used if the language is unknown, OR only one model is specified.
+ * The same as addModel(Language.UNKNOWN, model).
+ */
+ public BertEmbedder.Builder addDefaultModel(Path model) {
+ addModel(Language.UNKNOWN, model);
+ return this;
+ }
+
+ public Map<Language, Path> getModels() { return models; }
+
+ public BertEmbedder build() {
+ if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied");
+ return new BertEmbedder(this);
+ }
+
+ }
+
+}
+
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java b/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java
new file mode 100644
index 00000000000..54f37d597ce
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java
@@ -0,0 +1,105 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.bert;
+
+import com.yahoo.collections.Tuple2;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.StemMode;
+import com.yahoo.language.process.Token;
+import com.yahoo.language.process.Tokenizer;
+
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableMap;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+
+/**
+ * A BERT embedder "model" - just a vocabulary of strings with a fixed id (index).
+ *
+ * Adapted from
+ * https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java
+ * licensed under the Apache License, Version 2.0
+ *
+ * @author bergum
+ * @author bratseth
+ */
+class Model {
+
+ final Path source;
+ final Language language;
+ private final NavigableMap<String, Integer> vocabulary;
+ private final Map<Integer, String> tokenId2Token;
+
+ Model(Language language, Path path) {
+ this.source = path;
+ this.language = language;
+
+ this.vocabulary = new TreeMap<>(Collections.reverseOrder());
+ this.tokenId2Token = new HashMap<>();
+ try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(path.toFile()),
+ StandardCharsets.UTF_8))) {
+ String token;
+ int i = 0;
+ while ((token = reader.readLine()) != null) {
+ this.vocabulary.put(token, i);
+ this.tokenId2Token.put(i, token);
+ i++;
+ }
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read a BERT model from " + path, e);
+ }
+
+ }
+
+ public List<Integer> embed(String text, Tokenizer tokenizer) {
+ List<Integer> ids = new ArrayList<>();
+ text = text.toLowerCase();
+ for (Token t : tokenizer.tokenize(text, language, StemMode.NONE, true)) {
+ String originalToken = t.getTokenString();
+ String candidate = originalToken;
+ int count = 0;
+ while (candidate.length() > 0 && !"##".equals(candidate)) {
+ Tuple2<String, Integer> entry = findLongestSubstring(candidate);
+ if (entry == null) break;
+ ids.add(entry.second);
+ candidate = "##" + candidate.substring(entry.first.length());
+ if (count++ > originalToken.length()) break;
+ }
+ }
+
+ return ids;
+ }
+
+ public List<String> segment(String text, Tokenizer tokenizer) {
+ return embed(text, tokenizer).stream().map(tokenId -> tokenId2Token.get(tokenId)).collect(Collectors.toList());
+ }
+
+ private Tuple2<String, Integer> findLongestSubstring(String candidate) {
+ NavigableMap<String, Integer> tailMap = this.vocabulary.tailMap(candidate, true);
+ if (tailMap.isEmpty())
+ return null;
+ String longestSubstring = tailMap.firstKey();
+ Integer id = tailMap.firstEntry().getValue();
+ int subStringLength = Math.min(candidate.length(), longestSubstring.length());
+ while (!candidate.startsWith(longestSubstring)) {
+ subStringLength--;
+ tailMap = tailMap.tailMap(candidate.substring(0, subStringLength), true);
+ if (tailMap.isEmpty())
+ return null;
+ longestSubstring = tailMap.firstKey();
+ id = tailMap.firstEntry().getValue();
+ }
+ return new Tuple2<>(longestSubstring, id);
+ }
+
+}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java
new file mode 100644
index 00000000000..e3f612f4114
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package com.yahoo.language.bert;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index 3afc85300d4..ff7f4ae42bc 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -3,17 +3,17 @@ package com.yahoo.language.sentencepiece;
import com.yahoo.api.annotations.Beta;
import com.google.inject.Inject;
+import com.yahoo.language.tools.Embed;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.HashMap;
+import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -38,7 +38,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
}
public SentencePieceEmbedder(Builder builder) {
- algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring());
+ algorithm = new SentencePieceAlgorithm(builder.getCollapseUnknowns(), builder.getScoring());
models = builder.getModels().entrySet()
.stream()
@@ -94,9 +94,6 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
* they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small
* it will be truncated.</p>
*
- * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token
- * position as value.</p>
- *
* <p>If the tensor is any other type IllegalArgumentException is thrown.</p>
*
* @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported.
@@ -105,40 +102,15 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
*/
@Override
public Tensor embed(String rawInput, Embedder.Context context, TensorType type) {
- if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
- // Build to a list first since we can't reverse a tensor builder
- List<Integer> values = embed(rawInput, context);
-
- long maxSize = values.size();
- if (type.dimensions().get(0).size().isPresent())
- maxSize = Math.min(maxSize, type.dimensions().get(0).size().get());
-
- Tensor.Builder builder = Tensor.Builder.of(type);
- for (int i = 0; i < maxSize; i++)
- builder.cell(values.get(i), i);
- return builder.build();
- }
- else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) {
- // Build to a list first since we can't reverse a tensor builder
- List<String> values = segment(rawInput, context.getLanguage());
-
- Tensor.Builder builder = Tensor.Builder.of(type);
- for (int i = 0; i < values.size(); i++)
- builder.cell(TensorAddress.ofLabels(values.get(i)), i);
- return builder.build();
- }
- else {
- throw new IllegalArgumentException("Don't know how to embed with SentencePiece into " + type);
- }
+ return Embed.asTensor(rawInput, this, context, type);
}
private <RESULTTYPE> void segment(String input, Language language,
ResultBuilder<RESULTTYPE> resultBuilder) {
- Model model = resolveFrom(language);
- algorithm.segment(input, resultBuilder, model);
+ algorithm.segment(input, resultBuilder, resolveModelFrom(language));
}
- private Model resolveFrom(Language language) {
+ private Model resolveModelFrom(Language language) {
// Disregard language if there is default model
if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
if (models.containsKey(language)) return models.get(language);
@@ -166,7 +138,7 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
public static class Builder {
- private final Map<Language, Path> models = new HashMap<>();
+ private final Map<Language, Path> models = new EnumMap<>(Language.class);
private boolean collapseUnknowns = true;
private Scoring scoring = Scoring.fewestSegments;
@@ -177,9 +149,8 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
collapseUnknowns = config.collapseUnknowns();
scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments
: Scoring.highestScore;
- for (SentencePieceConfig.Model model : config.model()) {
+ for (SentencePieceConfig.Model model : config.model())
addModel(Language.fromLanguageTag(model.language()), model.path());
- }
}
public void addModel(Language language, Path model) {
diff --git a/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java b/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java
new file mode 100644
index 00000000000..401347cc452
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/tools/Embed.java
@@ -0,0 +1,43 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.tools;
+
+import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.List;
+
+/**
+ * Component internal helpers for embedding
+ *
+ * @author bratseth
+ */
+public class Embed {
+
+ /**
+ * Convenience function which embeds the given string into the given tensor type (if possible),
+ * using the given embedder.
+ */
+ public static Tensor asTensor(String text,
+ Embedder embedder,
+ Embedder.Context context,
+ TensorType type) {
+ if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
+ // Build to a list first since we can't reverse a tensor builder
+ List<Integer> values = embedder.embed(text, context);
+
+ long maxSize = values.size();
+ if (type.dimensions().get(0).size().isPresent())
+ maxSize = Math.min(maxSize, type.dimensions().get(0).size().get());
+
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ for (int i = 0; i < maxSize; i++)
+ builder.cell(values.get(i), i);
+ return builder.build();
+ }
+ else {
+ throw new IllegalArgumentException("Don't know how to embed into " + type);
+ }
+ }
+
+}