diff options
13 files changed, 306 insertions, 276 deletions
diff --git a/linguistics-components/abi-spec.json b/linguistics-components/abi-spec.json index 6dba8b602bd..39666fd93a3 100644 --- a/linguistics-components/abi-spec.json +++ b/linguistics-components/abi-spec.json @@ -1,5 +1,22 @@ { - "com.yahoo.language.bert.BertConfig$Builder": { + "com.yahoo.language.sentencepiece.Scoring": { + "superClass": "java.lang.Enum", + "interfaces": [], + "attributes": [ + "public", + "final", + "enum" + ], + "methods": [ + "public static com.yahoo.language.sentencepiece.Scoring[] values()", + "public static com.yahoo.language.sentencepiece.Scoring valueOf(java.lang.String)" + ], + "fields": [ + "public static final enum com.yahoo.language.sentencepiece.Scoring highestScore", + "public static final enum com.yahoo.language.sentencepiece.Scoring fewestSegments" + ] + }, + "com.yahoo.language.sentencepiece.SentencePieceConfig$Builder": { "superClass": "java.lang.Object", "interfaces": [ "com.yahoo.config.ConfigInstance$Builder" @@ -9,23 +26,25 @@ ], "methods": [ "public void <init>()", - "public void <init>(com.yahoo.language.bert.BertConfig)", - "public com.yahoo.language.bert.BertConfig$Builder model(com.yahoo.language.bert.BertConfig$Model$Builder)", - "public com.yahoo.language.bert.BertConfig$Builder model(java.util.function.Consumer)", - "public com.yahoo.language.bert.BertConfig$Builder model(java.util.List)", + "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder collapseUnknowns(boolean)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder scoring(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.function.Consumer)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.List)", "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", "public final java.lang.String getDefMd5()", "public final java.lang.String getDefName()", "public final java.lang.String getDefNamespace()", "public final boolean getApplyOnRestart()", "public final void setApplyOnRestart(boolean)", - "public com.yahoo.language.bert.BertConfig build()" + "public com.yahoo.language.sentencepiece.SentencePieceConfig build()" ], "fields": [ "public java.util.List model" ] }, - "com.yahoo.language.bert.BertConfig$Model$Builder": { + "com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder": { "superClass": "java.lang.Object", "interfaces": [ "com.yahoo.config.ConfigBuilder" @@ -35,14 +54,14 @@ ], "methods": [ "public void <init>()", - "public void <init>(com.yahoo.language.bert.BertConfig$Model)", - "public com.yahoo.language.bert.BertConfig$Model$Builder language(java.lang.String)", - "public com.yahoo.language.bert.BertConfig$Model$Builder path(com.yahoo.config.FileReference)", - "public com.yahoo.language.bert.BertConfig$Model build()" + "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder language(java.lang.String)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder path(com.yahoo.config.FileReference)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model build()" ], "fields": [] }, - "com.yahoo.language.bert.BertConfig$Model": { + "com.yahoo.language.sentencepiece.SentencePieceConfig$Model": { "superClass": "com.yahoo.config.InnerNode", "interfaces": [], "attributes": [ @@ -50,13 +69,13 @@ "final" ], "methods": [ - "public void <init>(com.yahoo.language.bert.BertConfig$Model$Builder)", + "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)", "public java.lang.String language()", "public java.nio.file.Path path()" ], "fields": [] }, - "com.yahoo.language.bert.BertConfig$Producer": { + "com.yahoo.language.sentencepiece.SentencePieceConfig$Producer": { "superClass": "java.lang.Object", "interfaces": [ "com.yahoo.config.ConfigInstance$Producer" @@ -67,11 +86,44 @@ "abstract" ], "methods": [ - "public abstract void getConfig(com.yahoo.language.bert.BertConfig$Builder)" + "public abstract void getConfig(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)" ], "fields": [] }, - "com.yahoo.language.bert.BertConfig": { + "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum": { + "superClass": "java.lang.Enum", + "interfaces": [], + "attributes": [ + "public", + "final", + "enum" + ], + "methods": [ + "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum[] values()", + "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum valueOf(java.lang.String)" + ], + "fields": [ + "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore", + "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments" + ] + }, + "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring": { + "superClass": "com.yahoo.config.EnumNode", + "interfaces": [], + "attributes": [ + "public", + "final" + ], + "methods": [ + "public void <init>()", + "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)" + ], + "fields": [ + "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore", + "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments" + ] + }, + "com.yahoo.language.sentencepiece.SentencePieceConfig": { "superClass": "com.yahoo.config.ConfigInstance", "interfaces": [], "attributes": [ @@ -83,9 +135,11 @@ "public static java.lang.String getDefName()", "public static java.lang.String getDefNamespace()", "public static java.lang.String getDefVersion()", - "public void <init>(com.yahoo.language.bert.BertConfig$Builder)", + "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)", + "public boolean collapseUnknowns()", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum scoring()", "public java.util.List model()", - "public com.yahoo.language.bert.BertConfig$Model model(int)" + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model model(int)" ], "fields": [ "public static final java.lang.String CONFIG_DEF_MD5", @@ -95,56 +149,47 @@ "public static final java.lang.String[] CONFIG_DEF_SCHEMA" ] }, - "com.yahoo.language.bert.BertEmbedder$Builder": { + "com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder": { "superClass": "java.lang.Object", "interfaces": [], "attributes": [ - "public" + "public", + "final" ], "methods": [ "public void <init>()", + "public void <init>(java.lang.String)", "public void addModel(com.yahoo.language.Language, java.nio.file.Path)", - "public com.yahoo.language.bert.BertEmbedder$Builder addDefaultModel(java.nio.file.Path)", + "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder addDefaultModel(java.nio.file.Path)", "public java.util.Map getModels()", - "public com.yahoo.language.bert.BertEmbedder build()" + "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setCollapseUnknowns(boolean)", + "public boolean getCollapseUnknowns()", + "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setScoring(com.yahoo.language.sentencepiece.Scoring)", + "public com.yahoo.language.sentencepiece.Scoring getScoring()", + "public com.yahoo.language.sentencepiece.SentencePieceEmbedder build()" ], "fields": [] }, - "com.yahoo.language.bert.BertEmbedder": { + "com.yahoo.language.sentencepiece.SentencePieceEmbedder": { "superClass": "java.lang.Object", "interfaces": [ - "com.yahoo.language.process.Embedder", - "com.yahoo.language.process.Segmenter" + "com.yahoo.language.process.Segmenter", + "com.yahoo.language.process.Embedder" ], "attributes": [ "public" ], "methods": [ - "public void <init>(com.yahoo.language.bert.BertConfig)", + "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)", + "public void <init>(com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder)", "public java.util.List segment(java.lang.String, com.yahoo.language.Language)", "public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)", - "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)" + "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)", + "public java.lang.String normalize(java.lang.String)" ], "fields": [] }, - "com.yahoo.language.sentencepiece.Scoring": { - "superClass": "java.lang.Enum", - "interfaces": [], - "attributes": [ - "public", - "final", - "enum" - ], - "methods": [ - "public static com.yahoo.language.sentencepiece.Scoring[] values()", - "public static com.yahoo.language.sentencepiece.Scoring valueOf(java.lang.String)" - ], - "fields": [ - "public static final enum com.yahoo.language.sentencepiece.Scoring highestScore", - "public static final enum com.yahoo.language.sentencepiece.Scoring fewestSegments" - ] - }, - "com.yahoo.language.sentencepiece.SentencePieceConfig$Builder": { + "com.yahoo.language.wordpiece.WordPieceConfig$Builder": { "superClass": "java.lang.Object", "interfaces": [ "com.yahoo.config.ConfigInstance$Builder" @@ -154,25 +199,24 @@ ], "methods": [ "public void <init>()", - "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder collapseUnknowns(boolean)", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder scoring(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.function.Consumer)", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.List)", + "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig)", + "public com.yahoo.language.wordpiece.WordPieceConfig$Builder subwordPrefix(java.lang.String)", + "public com.yahoo.language.wordpiece.WordPieceConfig$Builder model(com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder)", + "public com.yahoo.language.wordpiece.WordPieceConfig$Builder model(java.util.function.Consumer)", + "public com.yahoo.language.wordpiece.WordPieceConfig$Builder model(java.util.List)", "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", "public final java.lang.String getDefMd5()", "public final java.lang.String getDefName()", "public final java.lang.String getDefNamespace()", "public final boolean getApplyOnRestart()", "public final void setApplyOnRestart(boolean)", - "public com.yahoo.language.sentencepiece.SentencePieceConfig build()" + "public com.yahoo.language.wordpiece.WordPieceConfig build()" ], "fields": [ "public java.util.List model" ] }, - "com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder": { + "com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder": { "superClass": "java.lang.Object", "interfaces": [ "com.yahoo.config.ConfigBuilder" @@ -182,14 +226,14 @@ ], "methods": [ "public void <init>()", - "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model)", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder language(java.lang.String)", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder path(com.yahoo.config.FileReference)", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model build()" + "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig$Model)", + "public com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder language(java.lang.String)", + "public com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder path(com.yahoo.config.FileReference)", + "public com.yahoo.language.wordpiece.WordPieceConfig$Model build()" ], "fields": [] }, - "com.yahoo.language.sentencepiece.SentencePieceConfig$Model": { + "com.yahoo.language.wordpiece.WordPieceConfig$Model": { "superClass": "com.yahoo.config.InnerNode", "interfaces": [], "attributes": [ @@ -197,13 +241,13 @@ "final" ], "methods": [ - "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)", + "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder)", "public java.lang.String language()", "public java.nio.file.Path path()" ], "fields": [] }, - "com.yahoo.language.sentencepiece.SentencePieceConfig$Producer": { + "com.yahoo.language.wordpiece.WordPieceConfig$Producer": { "superClass": "java.lang.Object", "interfaces": [ "com.yahoo.config.ConfigInstance$Producer" @@ -214,44 +258,11 @@ "abstract" ], "methods": [ - "public abstract void getConfig(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)" + "public abstract void getConfig(com.yahoo.language.wordpiece.WordPieceConfig$Builder)" ], "fields": [] }, - "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum": { - "superClass": "java.lang.Enum", - "interfaces": [], - "attributes": [ - "public", - "final", - "enum" - ], - "methods": [ - "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum[] values()", - "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum valueOf(java.lang.String)" - ], - "fields": [ - "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore", - "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments" - ] - }, - "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring": { - "superClass": "com.yahoo.config.EnumNode", - "interfaces": [], - "attributes": [ - "public", - "final" - ], - "methods": [ - "public void <init>()", - "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)" - ], - "fields": [ - "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore", - "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments" - ] - }, - "com.yahoo.language.sentencepiece.SentencePieceConfig": { + "com.yahoo.language.wordpiece.WordPieceConfig": { "superClass": "com.yahoo.config.ConfigInstance", "interfaces": [], "attributes": [ @@ -263,11 +274,10 @@ "public static java.lang.String getDefName()", "public static java.lang.String getDefNamespace()", "public static java.lang.String getDefVersion()", - "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)", - "public boolean collapseUnknowns()", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum scoring()", + "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig$Builder)", + "public java.lang.String subwordPrefix()", "public java.util.List model()", - "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model model(int)" + "public com.yahoo.language.wordpiece.WordPieceConfig$Model model(int)" ], "fields": [ "public static final java.lang.String CONFIG_DEF_MD5", @@ -277,41 +287,39 @@ "public static final java.lang.String[] CONFIG_DEF_SCHEMA" ] }, - "com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder": { + "com.yahoo.language.wordpiece.WordPieceEmbedder$Builder": { "superClass": "java.lang.Object", "interfaces": [], "attributes": [ - "public" + "public", + "final" ], "methods": [ "public void <init>()", + "public void <init>(java.lang.String)", + "public com.yahoo.language.wordpiece.WordPieceEmbedder$Builder setSubwordPrefix(java.lang.String)", + "public java.lang.String getSubwordPrefix()", "public void addModel(com.yahoo.language.Language, java.nio.file.Path)", - "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder addDefaultModel(java.nio.file.Path)", + "public com.yahoo.language.wordpiece.WordPieceEmbedder$Builder addDefaultModel(java.nio.file.Path)", "public java.util.Map getModels()", - "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setCollapseUnknowns(boolean)", - "public boolean getCollapseUnknowns()", - "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setScoring(com.yahoo.language.sentencepiece.Scoring)", - "public com.yahoo.language.sentencepiece.Scoring getScoring()", - "public com.yahoo.language.sentencepiece.SentencePieceEmbedder build()" + "public com.yahoo.language.wordpiece.WordPieceEmbedder build()" ], "fields": [] }, - "com.yahoo.language.sentencepiece.SentencePieceEmbedder": { + "com.yahoo.language.wordpiece.WordPieceEmbedder": { "superClass": "java.lang.Object", "interfaces": [ - "com.yahoo.language.process.Segmenter", - "com.yahoo.language.process.Embedder" + "com.yahoo.language.process.Embedder", + "com.yahoo.language.process.Segmenter" ], "attributes": [ "public" ], "methods": [ - "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)", - "public void <init>(com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder)", + "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig)", "public java.util.List segment(java.lang.String, com.yahoo.language.Language)", "public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)", - "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)", - "public java.lang.String normalize(java.lang.String)" + "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)" ], "fields": [] } diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index ff7f4ae42bc..31964eac514 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -10,6 +10,7 @@ import com.yahoo.language.process.Segmenter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.io.File; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; @@ -136,13 +137,16 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { return b.toString(); } - public static class Builder { + public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); private boolean collapseUnknowns = true; private Scoring scoring = Scoring.fewestSegments; - public Builder() { + public Builder() {} + + public Builder(String defaultModelFile) { + addDefaultModel(new File(defaultModelFile).toPath()); } private Builder(SentencePieceConfig config) { diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java index 54f37d597ce..ce996b85313 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java @@ -1,5 +1,5 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.bert; +package com.yahoo.language.wordpiece; import com.yahoo.collections.Tuple2; import com.yahoo.language.Language; @@ -23,7 +23,7 @@ import java.util.TreeMap; import java.util.stream.Collectors; /** - * A BERT embedder "model" - just a vocabulary of strings with a fixed id (index). + * A WordPiece embedder "model" - just a vocabulary of strings with a fixed id (index). * * Adapted from * https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java @@ -34,12 +34,14 @@ import java.util.stream.Collectors; */ class Model { - final Path source; - final Language language; + private final String subwordPrefix; + private final Path source; + private final Language language; private final NavigableMap<String, Integer> vocabulary; private final Map<Integer, String> tokenId2Token; - Model(Language language, Path path) { + Model(String subwordPrefix, Language language, Path path) { + this.subwordPrefix = subwordPrefix; this.source = path; this.language = language; @@ -56,23 +58,25 @@ class Model { } } catch (IOException e) { - throw new IllegalArgumentException("Could not read a BERT model from " + path, e); + throw new IllegalArgumentException("Could not read a WordPiece model from " + path, e); } } - public List<Integer> embed(String text, Tokenizer tokenizer) { + Language language() { return language; } + + List<Integer> embed(String text, Tokenizer tokenizer) { List<Integer> ids = new ArrayList<>(); text = text.toLowerCase(); for (Token t : tokenizer.tokenize(text, language, StemMode.NONE, true)) { String originalToken = t.getTokenString(); String candidate = originalToken; int count = 0; - while (candidate.length() > 0 && !"##".equals(candidate)) { + while (candidate.length() > 0 && !candidate.equals(subwordPrefix)) { Tuple2<String, Integer> entry = findLongestSubstring(candidate); if (entry == null) break; ids.add(entry.second); - candidate = "##" + candidate.substring(entry.first.length()); + candidate = subwordPrefix + candidate.substring(entry.first.length()); if (count++ > originalToken.length()) break; } } @@ -80,7 +84,7 @@ class Model { return ids; } - public List<String> segment(String text, Tokenizer tokenizer) { + List<String> segment(String text, Tokenizer tokenizer) { return embed(text, tokenizer).stream().map(tokenId -> tokenId2Token.get(tokenId)).collect(Collectors.toList()); } @@ -102,4 +106,9 @@ class Model { return new Tuple2<>(longestSubstring, id); } + @Override + public String toString() { + return "WordPiece model for " + language + ": '" + source + "'"; + } + } diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java index c2b19391e74..08de05f351a 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java @@ -1,5 +1,5 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.bert; +package com.yahoo.language.wordpiece; import com.google.inject.Inject; import com.yahoo.language.tools.Embed; @@ -10,7 +10,9 @@ import com.yahoo.language.process.Tokenizer; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.language.wordpiece.WordPieceConfig; +import java.io.File; import java.nio.file.Path; import java.util.EnumMap; import java.util.List; @@ -18,35 +20,37 @@ import java.util.Map; import java.util.stream.Collectors; /** - * An embedder to use with BERT models: Text is tokenized into tokens from a configured vocabulary, + * An implementation of the WordPiece embedder, usually used with BERT models, + * see https://arxiv.org/pdf/1609.08144v2.pdf + * Text is tokenized into tokens from a configured vocabulary, * and optionally returned as a 1-d dense tensor of token ids. * * @author bratseth */ -public class BertEmbedder implements Embedder, Segmenter { +public class WordPieceEmbedder implements Embedder, Segmenter { private final Map<Language, Model> models; private final Tokenizer tokenizer; @Inject - public BertEmbedder(BertConfig config) { + public WordPieceEmbedder(WordPieceConfig config) { this(new Builder(config)); } - private BertEmbedder(Builder builder) { + private WordPieceEmbedder(Builder builder) { super(); this.tokenizer = new SimpleLinguistics().getTokenizer(); // always just split on spaces etc. and lowercase models = builder.getModels().entrySet() .stream() - .map(e -> new Model(e.getKey(), e.getValue())) - .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); + .map(e -> new Model(builder.getSubwordPrefix(), e.getKey(), e.getValue())) + .collect(Collectors.toUnmodifiableMap(m -> m.language(), m -> m)); if (models.isEmpty()) - throw new IllegalArgumentException("BertEmbedder requires at least one model configured"); + throw new IllegalArgumentException("WordPieceEmbedder requires at least one model configured"); } /** - * Segments the given text into token segments from the BERT vocabulary. + * Segments the given text into token segments from the WordPiece vocabulary. * * @param text the text to segment. The text should be of a language using space-separated words. * @return the list of zero or more token ids resulting from segmenting the input text @@ -57,7 +61,7 @@ public class BertEmbedder implements Embedder, Segmenter { } /** - * Segments the given text into token segments from the BERT vocabulary and returns the token ids. + * Segments the given text into token segments from the WordPiece vocabulary and returns the token ids. * * @param text the text to segment. The text should be of a language using space-separated words. * @param context the context which specifies the language used to select a model @@ -90,21 +94,33 @@ public class BertEmbedder implements Embedder, Segmenter { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); if (models.containsKey(language)) return models.get(language); - throw new IllegalArgumentException("No BERT model for language " + language + " is configured"); + throw new IllegalArgumentException("No WordPiece model for language " + language + " is configured"); } - public static class Builder { + public static final class Builder { + private String subwordPrefix = "##"; private final Map<Language, Path> models = new EnumMap<>(Language.class); - public Builder() { + public Builder() {} + + public Builder(String defaultModelFile) { + addDefaultModel(new File(defaultModelFile).toPath()); } - private Builder(BertConfig config) { - for (BertConfig.Model model : config.model()) + private Builder(WordPieceConfig config) { + this.subwordPrefix = config.subwordPrefix(); + for (WordPieceConfig.Model model : config.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); } + public Builder setSubwordPrefix(String prefix) { + this.subwordPrefix = subwordPrefix; + return this; + } + + public String getSubwordPrefix() { return subwordPrefix; } + public void addModel(Language language, Path model) { models.put(language, model); } @@ -113,16 +129,16 @@ public class BertEmbedder implements Embedder, Segmenter { * Adds the model that will be used if the language is unknown, OR only one model is specified. * The same as addModel(Language.UNKNOWN, model). */ - public BertEmbedder.Builder addDefaultModel(Path model) { + public WordPieceEmbedder.Builder addDefaultModel(Path model) { addModel(Language.UNKNOWN, model); return this; } public Map<Language, Path> getModels() { return models; } - public BertEmbedder build() { + public WordPieceEmbedder build() { if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); - return new BertEmbedder(this); + return new WordPieceEmbedder(this); } } diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java index e3f612f4114..0bbb6f001f5 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage @PublicApi -package com.yahoo.language.bert; +package com.yahoo.language.wordpiece; import com.yahoo.api.annotations.PublicApi; import com.yahoo.osgi.annotation.ExportPackage; diff --git a/linguistics-components/src/main/resources/configdefinitions/language.bert.bert.def b/linguistics-components/src/main/resources/configdefinitions/language.wordpiece.word-piece.def index 86d338758d0..08592250eb5 100644 --- a/linguistics-components/src/main/resources/configdefinitions/language.bert.bert.def +++ b/linguistics-components/src/main/resources/configdefinitions/language.wordpiece.word-piece.def @@ -1,8 +1,11 @@ # Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -# Configures com.yahoo.language.bert.BertEmbedder +# Configures com.yahoo.language.wordpiece.WordPieceEmbedder -namespace=language.bert +namespace=language.wordpiece + +# The prefix to prepend to subword tokens +subwordPrefix string default="##" # The language a model is for, one of the language tags in com.yahoo.language.Language. # Use "unknown" for a model to be used for any language (i.e by default). diff --git a/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java b/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java deleted file mode 100644 index 1bc25e0d217..00000000000 --- a/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.bert; - -import com.yahoo.config.FileReference; -import com.yahoo.language.Language; -import com.yahoo.language.process.Embedder; -import com.yahoo.language.simple.SimpleLinguistics; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import org.junit.Test; - -import java.io.File; -import java.util.List; - -import static org.junit.Assert.assertEquals; - -/** - * Tests the BERT embedder - * - * @author bratseth - */ -public class BertEmbedderTest { - - private static final String vocabulary = "src/test/models/bert/bert-base-uncased-vocab.txt"; - - @Test - public void testBertEmbedder() { - var embedder = new BertEmbedder.Builder().addDefaultModel(new File(vocabulary).toPath()).build(); - var expectedTokenIds = List.of(2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622); - assertEquals(expectedTokenIds, embedder.embed("what was the impact of the manhattan project", - new Embedder.Context("destination"))); - - var expectedTokens = List.of("what", "was", "the", "impact", "of", "the", "manhattan", "project"); - assertEquals(expectedTokens, embedder.segment("what was the impact of the manhattan project", - Language.ENGLISH)); - - var expectedDenseTensor = Tensor.from("tensor(x[8]):" + expectedTokenIds); - assertEquals(expectedDenseTensor, embedder.embed("what was the impact of the manhattan project", - new Embedder.Context("destination"), - expectedDenseTensor.type())); - } - - @Test - public void testBertEmbedderConfiguration() { - var config = new BertConfig.Builder().model(new BertConfig.Model.Builder().language("unknown") - .path(new FileReference(vocabulary))) - .build(); - var embedder = new BertEmbedder(config); - var expectedTokenIds = List.of(2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622); - assertEquals(expectedTokenIds, embedder.embed("what was the impact of the manhattan project", - new Embedder.Context("destination"))); - } - -} diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java index 1ed2271f774..19cb2267655 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java @@ -4,6 +4,7 @@ package com.yahoo.language.sentencepiece; import com.yahoo.config.FileReference; import com.yahoo.language.Language; +import com.yahoo.language.tools.EmbedderTester; import org.junit.Test; /** @@ -15,7 +16,7 @@ public class SentencePieceConfigurationTest { public void testEnglishTokenization() { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); } @@ -25,7 +26,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); b.collapseUnknowns(false); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); } @@ -34,7 +35,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); b.scoring(SentencePieceConfig.Scoring.highestScore); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("hello", "▁h", "el", "lo"); } @@ -43,7 +44,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b); addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); - var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); + var tester = new EmbedderTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index 8b3e2988c43..2fbafb23485 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -3,6 +3,7 @@ package com.yahoo.language.sentencepiece; import com.yahoo.language.Language; +import com.yahoo.language.tools.EmbedderTester; import org.junit.Test; import java.io.File; @@ -13,8 +14,8 @@ import java.io.File; public class SentencePieceTest { @Test - public void testEnglishTokenization() { - var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + public void testEnglishSegmenting() { + var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build()); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); tester.assertSegmented("hel", "▁hel"); @@ -36,33 +37,28 @@ public class SentencePieceTest { } @Test - public void testIntegerListEncoding() { - var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - tester.assertEmbedded("hello, world!", 908, 1418, 9934, 501, 9960); - tester.assertEmbedded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960); - } - - @Test - public void testDenseTensorEncoding() { - var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - tester.assertEmbedded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]"); - tester.assertEmbedded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]"); - tester.assertEmbedded("hello, world!", "tensor(d[2])", "[908,1418]"); + public void testEnglishEmbedding() { + var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build()); + tester.assertEmbedded("hello, world!", "tensor(d[10])", 908, 1418, 9934, 501, 9960); + tester.assertEmbedded("Hello, world!", "tensor(d[10])", 9912, 0, 6595, 9934, 501, 9960); + tester.assertEmbedded("hello, world!", "tensor(d[2])", 908, 1418, 9934, 501, 9960); } @Test public void testNoCollapse() { - var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder() - .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) - .setCollapseUnknowns(false)); + var builder = new SentencePieceEmbedder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setCollapseUnknowns(false); + var tester = new EmbedderTester(builder.build()); tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); } @Test public void testHighestScore() { - var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder() - .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) - .setScoring(Scoring.highestScore)); + var builder = new SentencePieceEmbedder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setScoring(Scoring.highestScore); + var tester = new EmbedderTester(builder.build()); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); tester.assertSegmented("hel", "▁h", "el"); @@ -74,7 +70,7 @@ public class SentencePieceTest { SentencePieceEmbedder.Builder builder = new SentencePieceEmbedder.Builder(); builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - var tester = new SentencePieceTester(builder); + var tester = new EmbedderTester(builder.build()); tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java deleted file mode 100644 index 4dae53c60df..00000000000 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -// - -package com.yahoo.language.sentencepiece; - -import com.yahoo.language.Language; -import com.yahoo.language.process.Embedder; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.nio.file.Path; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; - -class SentencePieceTester { - - private final SentencePieceEmbedder embedder; - - public SentencePieceTester(Path model) { - this(new SentencePieceEmbedder.Builder().addDefaultModel(model)); - } - - public SentencePieceTester(SentencePieceEmbedder.Builder builder) { - this(builder.build()); - } - - public SentencePieceTester(SentencePieceEmbedder embedder) { - this.embedder = embedder; - } - - public void assertEmbedded(String input, Integer... expectedCodes) { - assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray()); - } - - public void assertEmbedded(String input, String tensorType, String tensor) { - TensorType type = TensorType.fromSpec(tensorType); - Tensor expected = Tensor.from(type, tensor); - assertEquals(expected, embedder.embed(input, new Embedder.Context("test"), type)); - } - - public void assertSegmented(String input, String... expectedSegments) { - assertSegmented(Language.UNKNOWN, input, expectedSegments); - } - - public void assertSegmented(Language language, String input, String... expectedSegments) { - assertArrayEquals(expectedSegments, embedder.segment(input, language).toArray()); - } - -} diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java new file mode 100644 index 00000000000..9599e60e556 --- /dev/null +++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java @@ -0,0 +1,59 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.tools; + +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Segmenter; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.Arrays; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tester of embedders. + * + * @author bratseth + */ +public class EmbedderTester { + + private final Embedder embedder; + + public EmbedderTester(Embedder embedder) { + this.embedder = embedder; + } + + /** + * Tests both embedding to a list of id's and encoding the same ids to a vector of the given type. + * + * @param expectedCodes all the expected codes of the given input, not including any trailing 0-paddings + * required for the tensor only + */ + public void assertEmbedded(String input, String tensorType, Integer... expectedCodes) { + TensorType type = TensorType.fromSpec(tensorType); + assertEquals(1, type.dimensions().size()); + assertTrue(type.dimensions().get(0).isIndexed()); + + int tensorSize = type.dimensions().get(0).size().get().intValue(); + + assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray()); + + var builder = Tensor.Builder.of(type); + for (int i = 0; i < tensorSize; i++) + builder.cell(i < expectedCodes.length ? expectedCodes[i] : 0, i); + assertEquals(builder.build(), embedder.embed(input, new Embedder.Context("destination"), type)); + } + + public void assertSegmented(String input, String... expectedSegments) { + assertSegmented(Language.UNKNOWN, input, expectedSegments); + } + + public void assertSegmented(Language language, String input, String... expectedSegments) { + assertArrayEquals(expectedSegments, ((Segmenter)embedder).segment(input, language).toArray()); + } + +} diff --git a/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java new file mode 100644 index 00000000000..4cbfe541327 --- /dev/null +++ b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java @@ -0,0 +1,38 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.wordpiece; + +import com.yahoo.config.FileReference; +import com.yahoo.language.tools.EmbedderTester; +import org.junit.Test; + +/** + * Tests the WordPiece embedder + * + * @author bratseth + */ +public class WordPieceEmbedderTest { + + private static final String vocabulary = "src/test/models/wordpiece/bert-base-uncased-vocab.txt"; + + @Test + public void testWordPieceEmbedder() { + var tester = new EmbedderTester(new WordPieceEmbedder.Builder(vocabulary).build()); + tester.assertEmbedded("what was the impact of the manhattan project", + "tensor(x[8])", + 2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622); + tester.assertSegmented("what was the impact of the manhattan project", + "what", "was", "the", "impact", "of", "the", "manhattan", "project"); + } + + @Test + public void testWordPieceEmbedderConfiguration() { + var config = new WordPieceConfig.Builder().model(new WordPieceConfig.Model.Builder() + .language("unknown") + .path(new FileReference(vocabulary))) + .build(); + var tester = new EmbedderTester(new WordPieceEmbedder(config)); + tester.assertSegmented("what was the impact of the manhattan project", + "what", "was", "the", "impact", "of", "the", "manhattan", "project"); + } + +} diff --git a/linguistics-components/src/test/models/bert/bert-base-uncased-vocab.txt b/linguistics-components/src/test/models/wordpiece/bert-base-uncased-vocab.txt index fb140275c15..fb140275c15 100644 --- a/linguistics-components/src/test/models/bert/bert-base-uncased-vocab.txt +++ b/linguistics-components/src/test/models/wordpiece/bert-base-uncased-vocab.txt |