summary | refs | log | tree | commit | diff | stats
path: root/linguistics-components/src/main/java/com/yahoo/language
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language')
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java | 8
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java (renamed from linguistics-components/src/main/java/com/yahoo/language/bert/Model.java) | 29
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java (renamed from linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java) | 52
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java (renamed from linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java) | 2
4 files changed, 60 insertions, 31 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index ff7f4ae42bc..31964eac514 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -10,6 +10,7 @@ import com.yahoo.language.process.Segmenter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
@@ -136,13 +137,16 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
return b.toString();
}
- public static class Builder {
+ public static final class Builder {
private final Map<Language, Path> models = new EnumMap<>(Language.class);
private boolean collapseUnknowns = true;
private Scoring scoring = Scoring.fewestSegments;
- public Builder() {
+ public Builder() {}
+
+ public Builder(String defaultModelFile) {
+ addDefaultModel(new File(defaultModelFile).toPath());
}
private Builder(SentencePieceConfig config) {
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java
index 54f37d597ce..ce996b85313 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java
@@ -1,5 +1,5 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.language.bert;
+package com.yahoo.language.wordpiece;
import com.yahoo.collections.Tuple2;
import com.yahoo.language.Language;
@@ -23,7 +23,7 @@ import java.util.TreeMap;
import java.util.stream.Collectors;
/**
- * A BERT embedder "model" - just a vocabulary of strings with a fixed id (index).
+ * A WordPiece embedder "model" - just a vocabulary of strings with a fixed id (index).
*
* Adapted from
* https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java
@@ -34,12 +34,14 @@ import java.util.stream.Collectors;
*/
class Model {
- final Path source;
- final Language language;
+ private final String subwordPrefix;
+ private final Path source;
+ private final Language language;
private final NavigableMap<String, Integer> vocabulary;
private final Map<Integer, String> tokenId2Token;
- Model(Language language, Path path) {
+ Model(String subwordPrefix, Language language, Path path) {
+ this.subwordPrefix = subwordPrefix;
this.source = path;
this.language = language;
@@ -56,23 +58,25 @@ class Model {
}
}
catch (IOException e) {
- throw new IllegalArgumentException("Could not read a BERT model from " + path, e);
+ throw new IllegalArgumentException("Could not read a WordPiece model from " + path, e);
}
}
- public List<Integer> embed(String text, Tokenizer tokenizer) {
+ Language language() { return language; }
+
+ List<Integer> embed(String text, Tokenizer tokenizer) {
List<Integer> ids = new ArrayList<>();
text = text.toLowerCase();
for (Token t : tokenizer.tokenize(text, language, StemMode.NONE, true)) {
String originalToken = t.getTokenString();
String candidate = originalToken;
int count = 0;
- while (candidate.length() > 0 && !"##".equals(candidate)) {
+ while (candidate.length() > 0 && !candidate.equals(subwordPrefix)) {
Tuple2<String, Integer> entry = findLongestSubstring(candidate);
if (entry == null) break;
ids.add(entry.second);
- candidate = "##" + candidate.substring(entry.first.length());
+ candidate = subwordPrefix + candidate.substring(entry.first.length());
if (count++ > originalToken.length()) break;
}
}
@@ -80,7 +84,7 @@ class Model {
return ids;
}
- public List<String> segment(String text, Tokenizer tokenizer) {
+ List<String> segment(String text, Tokenizer tokenizer) {
return embed(text, tokenizer).stream().map(tokenId -> tokenId2Token.get(tokenId)).collect(Collectors.toList());
}
@@ -102,4 +106,9 @@ class Model {
return new Tuple2<>(longestSubstring, id);
}
+ @Override
+ public String toString() {
+ return "WordPiece model for " + language + ": '" + source + "'";
+ }
+
}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java
index c2b19391e74..08de05f351a 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java
@@ -1,5 +1,5 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.language.bert;
+package com.yahoo.language.wordpiece;
import com.google.inject.Inject;
import com.yahoo.language.tools.Embed;
@@ -10,7 +10,9 @@ import com.yahoo.language.process.Tokenizer;
import com.yahoo.language.simple.SimpleLinguistics;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.language.wordpiece.WordPieceConfig;
+import java.io.File;
import java.nio.file.Path;
import java.util.EnumMap;
import java.util.List;
@@ -18,35 +20,37 @@ import java.util.Map;
import java.util.stream.Collectors;
/**
- * An embedder to use with BERT models: Text is tokenized into tokens from a configured vocabulary,
+ * An implementation of the WordPiece embedder, usually used with BERT models,
+ * see https://arxiv.org/pdf/1609.08144v2.pdf
+ * Text is tokenized into tokens from a configured vocabulary,
* and optionally returned as a 1-d dense tensor of token ids.
*
* @author bratseth
*/
-public class BertEmbedder implements Embedder, Segmenter {
+public class WordPieceEmbedder implements Embedder, Segmenter {
private final Map<Language, Model> models;
private final Tokenizer tokenizer;
@Inject
- public BertEmbedder(BertConfig config) {
+ public WordPieceEmbedder(WordPieceConfig config) {
this(new Builder(config));
}
- private BertEmbedder(Builder builder) {
+ private WordPieceEmbedder(Builder builder) {
super();
this.tokenizer = new SimpleLinguistics().getTokenizer(); // always just split on spaces etc. and lowercase
models = builder.getModels().entrySet()
.stream()
- .map(e -> new Model(e.getKey(), e.getValue()))
- .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m));
+ .map(e -> new Model(builder.getSubwordPrefix(), e.getKey(), e.getValue()))
+ .collect(Collectors.toUnmodifiableMap(m -> m.language(), m -> m));
if (models.isEmpty())
- throw new IllegalArgumentException("BertEmbedder requires at least one model configured");
+ throw new IllegalArgumentException("WordPieceEmbedder requires at least one model configured");
}
/**
- * Segments the given text into token segments from the BERT vocabulary.
+ * Segments the given text into token segments from the WordPiece vocabulary.
*
* @param text the text to segment. The text should be of a language using space-separated words.
* @return the list of zero or more token ids resulting from segmenting the input text
@@ -57,7 +61,7 @@ public class BertEmbedder implements Embedder, Segmenter {
}
/**
- * Segments the given text into token segments from the BERT vocabulary and returns the token ids.
+ * Segments the given text into token segments from the WordPiece vocabulary and returns the token ids.
*
* @param text the text to segment. The text should be of a language using space-separated words.
* @param context the context which specifies the language used to select a model
@@ -90,21 +94,33 @@ public class BertEmbedder implements Embedder, Segmenter {
// Disregard language if there is default model
if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
if (models.containsKey(language)) return models.get(language);
- throw new IllegalArgumentException("No BERT model for language " + language + " is configured");
+ throw new IllegalArgumentException("No WordPiece model for language " + language + " is configured");
}
- public static class Builder {
+ public static final class Builder {
+ private String subwordPrefix = "##";
private final Map<Language, Path> models = new EnumMap<>(Language.class);
- public Builder() {
+ public Builder() {}
+
+ public Builder(String defaultModelFile) {
+ addDefaultModel(new File(defaultModelFile).toPath());
}
- private Builder(BertConfig config) {
- for (BertConfig.Model model : config.model())
+ private Builder(WordPieceConfig config) {
+ this.subwordPrefix = config.subwordPrefix();
+ for (WordPieceConfig.Model model : config.model())
addModel(Language.fromLanguageTag(model.language()), model.path());
}
+ public Builder setSubwordPrefix(String prefix) { // sets the marker prepended to non-initial word pieces
+ this.subwordPrefix = prefix; // store the caller-supplied value in the builder field
+ return this;
+ }
+
+ public String getSubwordPrefix() { return subwordPrefix; }
+
public void addModel(Language language, Path model) {
models.put(language, model);
}
@@ -113,16 +129,16 @@ public class BertEmbedder implements Embedder, Segmenter {
* Adds the model that will be used if the language is unknown, OR only one model is specified.
* The same as addModel(Language.UNKNOWN, model).
*/
- public BertEmbedder.Builder addDefaultModel(Path model) {
+ public WordPieceEmbedder.Builder addDefaultModel(Path model) {
addModel(Language.UNKNOWN, model);
return this;
}
public Map<Language, Path> getModels() { return models; }
- public BertEmbedder build() {
+ public WordPieceEmbedder build() {
if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied");
- return new BertEmbedder(this);
+ return new WordPieceEmbedder(this);
}
}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java
index e3f612f4114..0bbb6f001f5 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
@ExportPackage
@PublicApi
-package com.yahoo.language.bert;
+package com.yahoo.language.wordpiece;
import com.yahoo.api.annotations.PublicApi;
import com.yahoo.osgi.annotation.ExportPackage;