summaryrefslogtreecommitdiffstats
path: root/linguistics-components
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@gmail.com> 2021-12-17 12:41:17 +0100
committer Jon Bratseth <bratseth@gmail.com> 2021-12-17 12:41:17 +0100
commit 601b117281b74a578126a0f3effead55bc79c680 (patch)
tree 29619184a8459763cc024b23e74960e6c9ec7f81 /linguistics-components
parent 767cb63af0f530605180f5438767406e1db27520 (diff)
BERT -> WordPiece, make subword prefix configurable
Diffstat (limited to 'linguistics-components')
-rw-r--r--linguistics-components/abi-spec.json236
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java8
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java (renamed from linguistics-components/src/main/java/com/yahoo/language/bert/Model.java)29
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java (renamed from linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java)52
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java (renamed from linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java)2
-rw-r--r--linguistics-components/src/main/resources/configdefinitions/language.wordpiece.word-piece.def (renamed from linguistics-components/src/main/resources/configdefinitions/language.bert.bert.def)7
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java54
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java9
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java38
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java50
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java59
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java38
-rw-r--r--linguistics-components/src/test/models/wordpiece/bert-base-uncased-vocab.txt (renamed from linguistics-components/src/test/models/bert/bert-base-uncased-vocab.txt)0
13 files changed, 306 insertions, 276 deletions
diff --git a/linguistics-components/abi-spec.json b/linguistics-components/abi-spec.json
index 6dba8b602bd..39666fd93a3 100644
--- a/linguistics-components/abi-spec.json
+++ b/linguistics-components/abi-spec.json
@@ -1,5 +1,22 @@
{
- "com.yahoo.language.bert.BertConfig$Builder": {
+ "com.yahoo.language.sentencepiece.Scoring": {
+ "superClass": "java.lang.Enum",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final",
+ "enum"
+ ],
+ "methods": [
+ "public static com.yahoo.language.sentencepiece.Scoring[] values()",
+ "public static com.yahoo.language.sentencepiece.Scoring valueOf(java.lang.String)"
+ ],
+ "fields": [
+ "public static final enum com.yahoo.language.sentencepiece.Scoring highestScore",
+ "public static final enum com.yahoo.language.sentencepiece.Scoring fewestSegments"
+ ]
+ },
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Builder": {
"superClass": "java.lang.Object",
"interfaces": [
"com.yahoo.config.ConfigInstance$Builder"
@@ -9,23 +26,25 @@
],
"methods": [
"public void <init>()",
- "public void <init>(com.yahoo.language.bert.BertConfig)",
- "public com.yahoo.language.bert.BertConfig$Builder model(com.yahoo.language.bert.BertConfig$Model$Builder)",
- "public com.yahoo.language.bert.BertConfig$Builder model(java.util.function.Consumer)",
- "public com.yahoo.language.bert.BertConfig$Builder model(java.util.List)",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder collapseUnknowns(boolean)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder scoring(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.function.Consumer)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.List)",
"public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
"public final java.lang.String getDefMd5()",
"public final java.lang.String getDefName()",
"public final java.lang.String getDefNamespace()",
"public final boolean getApplyOnRestart()",
"public final void setApplyOnRestart(boolean)",
- "public com.yahoo.language.bert.BertConfig build()"
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig build()"
],
"fields": [
"public java.util.List model"
]
},
- "com.yahoo.language.bert.BertConfig$Model$Builder": {
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder": {
"superClass": "java.lang.Object",
"interfaces": [
"com.yahoo.config.ConfigBuilder"
@@ -35,14 +54,14 @@
],
"methods": [
"public void <init>()",
- "public void <init>(com.yahoo.language.bert.BertConfig$Model)",
- "public com.yahoo.language.bert.BertConfig$Model$Builder language(java.lang.String)",
- "public com.yahoo.language.bert.BertConfig$Model$Builder path(com.yahoo.config.FileReference)",
- "public com.yahoo.language.bert.BertConfig$Model build()"
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder language(java.lang.String)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder path(com.yahoo.config.FileReference)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model build()"
],
"fields": []
},
- "com.yahoo.language.bert.BertConfig$Model": {
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Model": {
"superClass": "com.yahoo.config.InnerNode",
"interfaces": [],
"attributes": [
@@ -50,13 +69,13 @@
"final"
],
"methods": [
- "public void <init>(com.yahoo.language.bert.BertConfig$Model$Builder)",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)",
"public java.lang.String language()",
"public java.nio.file.Path path()"
],
"fields": []
},
- "com.yahoo.language.bert.BertConfig$Producer": {
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Producer": {
"superClass": "java.lang.Object",
"interfaces": [
"com.yahoo.config.ConfigInstance$Producer"
@@ -67,11 +86,44 @@
"abstract"
],
"methods": [
- "public abstract void getConfig(com.yahoo.language.bert.BertConfig$Builder)"
+ "public abstract void getConfig(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)"
],
"fields": []
},
- "com.yahoo.language.bert.BertConfig": {
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum": {
+ "superClass": "java.lang.Enum",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final",
+ "enum"
+ ],
+ "methods": [
+ "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum[] values()",
+ "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum valueOf(java.lang.String)"
+ ],
+ "fields": [
+ "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore",
+ "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments"
+ ]
+ },
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring": {
+ "superClass": "com.yahoo.config.EnumNode",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final"
+ ],
+ "methods": [
+ "public void <init>()",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)"
+ ],
+ "fields": [
+ "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore",
+ "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments"
+ ]
+ },
+ "com.yahoo.language.sentencepiece.SentencePieceConfig": {
"superClass": "com.yahoo.config.ConfigInstance",
"interfaces": [],
"attributes": [
@@ -83,9 +135,11 @@
"public static java.lang.String getDefName()",
"public static java.lang.String getDefNamespace()",
"public static java.lang.String getDefVersion()",
- "public void <init>(com.yahoo.language.bert.BertConfig$Builder)",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)",
+ "public boolean collapseUnknowns()",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum scoring()",
"public java.util.List model()",
- "public com.yahoo.language.bert.BertConfig$Model model(int)"
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model model(int)"
],
"fields": [
"public static final java.lang.String CONFIG_DEF_MD5",
@@ -95,56 +149,47 @@
"public static final java.lang.String[] CONFIG_DEF_SCHEMA"
]
},
- "com.yahoo.language.bert.BertEmbedder$Builder": {
+ "com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder": {
"superClass": "java.lang.Object",
"interfaces": [],
"attributes": [
- "public"
+ "public",
+ "final"
],
"methods": [
"public void <init>()",
+ "public void <init>(java.lang.String)",
"public void addModel(com.yahoo.language.Language, java.nio.file.Path)",
- "public com.yahoo.language.bert.BertEmbedder$Builder addDefaultModel(java.nio.file.Path)",
+ "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder addDefaultModel(java.nio.file.Path)",
"public java.util.Map getModels()",
- "public com.yahoo.language.bert.BertEmbedder build()"
+ "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setCollapseUnknowns(boolean)",
+ "public boolean getCollapseUnknowns()",
+ "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setScoring(com.yahoo.language.sentencepiece.Scoring)",
+ "public com.yahoo.language.sentencepiece.Scoring getScoring()",
+ "public com.yahoo.language.sentencepiece.SentencePieceEmbedder build()"
],
"fields": []
},
- "com.yahoo.language.bert.BertEmbedder": {
+ "com.yahoo.language.sentencepiece.SentencePieceEmbedder": {
"superClass": "java.lang.Object",
"interfaces": [
- "com.yahoo.language.process.Embedder",
- "com.yahoo.language.process.Segmenter"
+ "com.yahoo.language.process.Segmenter",
+ "com.yahoo.language.process.Embedder"
],
"attributes": [
"public"
],
"methods": [
- "public void <init>(com.yahoo.language.bert.BertConfig)",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder)",
"public java.util.List segment(java.lang.String, com.yahoo.language.Language)",
"public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)",
- "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)"
+ "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)",
+ "public java.lang.String normalize(java.lang.String)"
],
"fields": []
},
- "com.yahoo.language.sentencepiece.Scoring": {
- "superClass": "java.lang.Enum",
- "interfaces": [],
- "attributes": [
- "public",
- "final",
- "enum"
- ],
- "methods": [
- "public static com.yahoo.language.sentencepiece.Scoring[] values()",
- "public static com.yahoo.language.sentencepiece.Scoring valueOf(java.lang.String)"
- ],
- "fields": [
- "public static final enum com.yahoo.language.sentencepiece.Scoring highestScore",
- "public static final enum com.yahoo.language.sentencepiece.Scoring fewestSegments"
- ]
- },
- "com.yahoo.language.sentencepiece.SentencePieceConfig$Builder": {
+ "com.yahoo.language.wordpiece.WordPieceConfig$Builder": {
"superClass": "java.lang.Object",
"interfaces": [
"com.yahoo.config.ConfigInstance$Builder"
@@ -154,25 +199,24 @@
],
"methods": [
"public void <init>()",
- "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder collapseUnknowns(boolean)",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder scoring(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.function.Consumer)",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.List)",
+ "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig)",
+ "public com.yahoo.language.wordpiece.WordPieceConfig$Builder subwordPrefix(java.lang.String)",
+ "public com.yahoo.language.wordpiece.WordPieceConfig$Builder model(com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder)",
+ "public com.yahoo.language.wordpiece.WordPieceConfig$Builder model(java.util.function.Consumer)",
+ "public com.yahoo.language.wordpiece.WordPieceConfig$Builder model(java.util.List)",
"public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
"public final java.lang.String getDefMd5()",
"public final java.lang.String getDefName()",
"public final java.lang.String getDefNamespace()",
"public final boolean getApplyOnRestart()",
"public final void setApplyOnRestart(boolean)",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig build()"
+ "public com.yahoo.language.wordpiece.WordPieceConfig build()"
],
"fields": [
"public java.util.List model"
]
},
- "com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder": {
+ "com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder": {
"superClass": "java.lang.Object",
"interfaces": [
"com.yahoo.config.ConfigBuilder"
@@ -182,14 +226,14 @@
],
"methods": [
"public void <init>()",
- "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model)",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder language(java.lang.String)",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder path(com.yahoo.config.FileReference)",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model build()"
+ "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig$Model)",
+ "public com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder language(java.lang.String)",
+ "public com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder path(com.yahoo.config.FileReference)",
+ "public com.yahoo.language.wordpiece.WordPieceConfig$Model build()"
],
"fields": []
},
- "com.yahoo.language.sentencepiece.SentencePieceConfig$Model": {
+ "com.yahoo.language.wordpiece.WordPieceConfig$Model": {
"superClass": "com.yahoo.config.InnerNode",
"interfaces": [],
"attributes": [
@@ -197,13 +241,13 @@
"final"
],
"methods": [
- "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)",
+ "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig$Model$Builder)",
"public java.lang.String language()",
"public java.nio.file.Path path()"
],
"fields": []
},
- "com.yahoo.language.sentencepiece.SentencePieceConfig$Producer": {
+ "com.yahoo.language.wordpiece.WordPieceConfig$Producer": {
"superClass": "java.lang.Object",
"interfaces": [
"com.yahoo.config.ConfigInstance$Producer"
@@ -214,44 +258,11 @@
"abstract"
],
"methods": [
- "public abstract void getConfig(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)"
+ "public abstract void getConfig(com.yahoo.language.wordpiece.WordPieceConfig$Builder)"
],
"fields": []
},
- "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum": {
- "superClass": "java.lang.Enum",
- "interfaces": [],
- "attributes": [
- "public",
- "final",
- "enum"
- ],
- "methods": [
- "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum[] values()",
- "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum valueOf(java.lang.String)"
- ],
- "fields": [
- "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore",
- "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments"
- ]
- },
- "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring": {
- "superClass": "com.yahoo.config.EnumNode",
- "interfaces": [],
- "attributes": [
- "public",
- "final"
- ],
- "methods": [
- "public void <init>()",
- "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)"
- ],
- "fields": [
- "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore",
- "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments"
- ]
- },
- "com.yahoo.language.sentencepiece.SentencePieceConfig": {
+ "com.yahoo.language.wordpiece.WordPieceConfig": {
"superClass": "com.yahoo.config.ConfigInstance",
"interfaces": [],
"attributes": [
@@ -263,11 +274,10 @@
"public static java.lang.String getDefName()",
"public static java.lang.String getDefNamespace()",
"public static java.lang.String getDefVersion()",
- "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)",
- "public boolean collapseUnknowns()",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum scoring()",
+ "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig$Builder)",
+ "public java.lang.String subwordPrefix()",
"public java.util.List model()",
- "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model model(int)"
+ "public com.yahoo.language.wordpiece.WordPieceConfig$Model model(int)"
],
"fields": [
"public static final java.lang.String CONFIG_DEF_MD5",
@@ -277,41 +287,39 @@
"public static final java.lang.String[] CONFIG_DEF_SCHEMA"
]
},
- "com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder": {
+ "com.yahoo.language.wordpiece.WordPieceEmbedder$Builder": {
"superClass": "java.lang.Object",
"interfaces": [],
"attributes": [
- "public"
+ "public",
+ "final"
],
"methods": [
"public void <init>()",
+ "public void <init>(java.lang.String)",
+ "public com.yahoo.language.wordpiece.WordPieceEmbedder$Builder setSubwordPrefix(java.lang.String)",
+ "public java.lang.String getSubwordPrefix()",
"public void addModel(com.yahoo.language.Language, java.nio.file.Path)",
- "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder addDefaultModel(java.nio.file.Path)",
+ "public com.yahoo.language.wordpiece.WordPieceEmbedder$Builder addDefaultModel(java.nio.file.Path)",
"public java.util.Map getModels()",
- "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setCollapseUnknowns(boolean)",
- "public boolean getCollapseUnknowns()",
- "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setScoring(com.yahoo.language.sentencepiece.Scoring)",
- "public com.yahoo.language.sentencepiece.Scoring getScoring()",
- "public com.yahoo.language.sentencepiece.SentencePieceEmbedder build()"
+ "public com.yahoo.language.wordpiece.WordPieceEmbedder build()"
],
"fields": []
},
- "com.yahoo.language.sentencepiece.SentencePieceEmbedder": {
+ "com.yahoo.language.wordpiece.WordPieceEmbedder": {
"superClass": "java.lang.Object",
"interfaces": [
- "com.yahoo.language.process.Segmenter",
- "com.yahoo.language.process.Embedder"
+ "com.yahoo.language.process.Embedder",
+ "com.yahoo.language.process.Segmenter"
],
"attributes": [
"public"
],
"methods": [
- "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)",
- "public void <init>(com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder)",
+ "public void <init>(com.yahoo.language.wordpiece.WordPieceConfig)",
"public java.util.List segment(java.lang.String, com.yahoo.language.Language)",
"public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)",
- "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)",
- "public java.lang.String normalize(java.lang.String)"
+ "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)"
],
"fields": []
}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index ff7f4ae42bc..31964eac514 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -10,6 +10,7 @@ import com.yahoo.language.process.Segmenter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
@@ -136,13 +137,16 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
return b.toString();
}
- public static class Builder {
+ public static final class Builder {
private final Map<Language, Path> models = new EnumMap<>(Language.class);
private boolean collapseUnknowns = true;
private Scoring scoring = Scoring.fewestSegments;
- public Builder() {
+ public Builder() {}
+
+ public Builder(String defaultModelFile) {
+ addDefaultModel(new File(defaultModelFile).toPath());
}
private Builder(SentencePieceConfig config) {
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java
index 54f37d597ce..ce996b85313 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/bert/Model.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java
@@ -1,5 +1,5 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.language.bert;
+package com.yahoo.language.wordpiece;
import com.yahoo.collections.Tuple2;
import com.yahoo.language.Language;
@@ -23,7 +23,7 @@ import java.util.TreeMap;
import java.util.stream.Collectors;
/**
- * A BERT embedder "model" - just a vocabulary of strings with a fixed id (index).
+ * A WordPiece embedder "model" - just a vocabulary of strings with a fixed id (index).
*
* Adapted from
* https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizer.java
@@ -34,12 +34,14 @@ import java.util.stream.Collectors;
*/
class Model {
- final Path source;
- final Language language;
+ private final String subwordPrefix;
+ private final Path source;
+ private final Language language;
private final NavigableMap<String, Integer> vocabulary;
private final Map<Integer, String> tokenId2Token;
- Model(Language language, Path path) {
+ Model(String subwordPrefix, Language language, Path path) {
+ this.subwordPrefix = subwordPrefix;
this.source = path;
this.language = language;
@@ -56,23 +58,25 @@ class Model {
}
}
catch (IOException e) {
- throw new IllegalArgumentException("Could not read a BERT model from " + path, e);
+ throw new IllegalArgumentException("Could not read a WordPiece model from " + path, e);
}
}
- public List<Integer> embed(String text, Tokenizer tokenizer) {
+ Language language() { return language; }
+
+ List<Integer> embed(String text, Tokenizer tokenizer) {
List<Integer> ids = new ArrayList<>();
text = text.toLowerCase();
for (Token t : tokenizer.tokenize(text, language, StemMode.NONE, true)) {
String originalToken = t.getTokenString();
String candidate = originalToken;
int count = 0;
- while (candidate.length() > 0 && !"##".equals(candidate)) {
+ while (candidate.length() > 0 && !candidate.equals(subwordPrefix)) {
Tuple2<String, Integer> entry = findLongestSubstring(candidate);
if (entry == null) break;
ids.add(entry.second);
- candidate = "##" + candidate.substring(entry.first.length());
+ candidate = subwordPrefix + candidate.substring(entry.first.length());
if (count++ > originalToken.length()) break;
}
}
@@ -80,7 +84,7 @@ class Model {
return ids;
}
- public List<String> segment(String text, Tokenizer tokenizer) {
+ List<String> segment(String text, Tokenizer tokenizer) {
return embed(text, tokenizer).stream().map(tokenId -> tokenId2Token.get(tokenId)).collect(Collectors.toList());
}
@@ -102,4 +106,9 @@ class Model {
return new Tuple2<>(longestSubstring, id);
}
+ @Override
+ public String toString() {
+ return "WordPiece model for " + language + ": '" + source + "'";
+ }
+
}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java
index c2b19391e74..08de05f351a 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/bert/BertEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/WordPieceEmbedder.java
@@ -1,5 +1,5 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.language.bert;
+package com.yahoo.language.wordpiece;
import com.google.inject.Inject;
import com.yahoo.language.tools.Embed;
@@ -10,7 +10,9 @@ import com.yahoo.language.process.Tokenizer;
import com.yahoo.language.simple.SimpleLinguistics;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.language.wordpiece.WordPieceConfig;
+import java.io.File;
import java.nio.file.Path;
import java.util.EnumMap;
import java.util.List;
@@ -18,35 +20,37 @@ import java.util.Map;
import java.util.stream.Collectors;
/**
- * An embedder to use with BERT models: Text is tokenized into tokens from a configured vocabulary,
+ * An implementation of the WordPiece embedder, usually used with BERT models,
+ * see https://arxiv.org/pdf/1609.08144v2.pdf
+ * Text is tokenized into tokens from a configured vocabulary,
* and optionally returned as a 1-d dense tensor of token ids.
*
* @author bratseth
*/
-public class BertEmbedder implements Embedder, Segmenter {
+public class WordPieceEmbedder implements Embedder, Segmenter {
private final Map<Language, Model> models;
private final Tokenizer tokenizer;
@Inject
- public BertEmbedder(BertConfig config) {
+ public WordPieceEmbedder(WordPieceConfig config) {
this(new Builder(config));
}
- private BertEmbedder(Builder builder) {
+ private WordPieceEmbedder(Builder builder) {
super();
this.tokenizer = new SimpleLinguistics().getTokenizer(); // always just split on spaces etc. and lowercase
models = builder.getModels().entrySet()
.stream()
- .map(e -> new Model(e.getKey(), e.getValue()))
- .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m));
+ .map(e -> new Model(builder.getSubwordPrefix(), e.getKey(), e.getValue()))
+ .collect(Collectors.toUnmodifiableMap(m -> m.language(), m -> m));
if (models.isEmpty())
- throw new IllegalArgumentException("BertEmbedder requires at least one model configured");
+ throw new IllegalArgumentException("WordPieceEmbedder requires at least one model configured");
}
/**
- * Segments the given text into token segments from the BERT vocabulary.
+ * Segments the given text into token segments from the WordPiece vocabulary.
*
* @param text the text to segment. The text should be of a language using space-separated words.
* @return the list of zero or more token ids resulting from segmenting the input text
@@ -57,7 +61,7 @@ public class BertEmbedder implements Embedder, Segmenter {
}
/**
- * Segments the given text into token segments from the BERT vocabulary and returns the token ids.
+ * Segments the given text into token segments from the WordPiece vocabulary and returns the token ids.
*
* @param text the text to segment. The text should be of a language using space-separated words.
* @param context the context which specifies the language used to select a model
@@ -90,21 +94,33 @@ public class BertEmbedder implements Embedder, Segmenter {
// Disregard language if there is default model
if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
if (models.containsKey(language)) return models.get(language);
- throw new IllegalArgumentException("No BERT model for language " + language + " is configured");
+ throw new IllegalArgumentException("No WordPiece model for language " + language + " is configured");
}
- public static class Builder {
+ public static final class Builder {
+ private String subwordPrefix = "##";
private final Map<Language, Path> models = new EnumMap<>(Language.class);
- public Builder() {
+ public Builder() {}
+
+ public Builder(String defaultModelFile) {
+ addDefaultModel(new File(defaultModelFile).toPath());
}
- private Builder(BertConfig config) {
- for (BertConfig.Model model : config.model())
+ private Builder(WordPieceConfig config) {
+ this.subwordPrefix = config.subwordPrefix();
+ for (WordPieceConfig.Model model : config.model())
addModel(Language.fromLanguageTag(model.language()), model.path());
}
+ public Builder setSubwordPrefix(String prefix) { // sets the marker prepended to non-initial subword tokens; returns this builder for chaining
+ this.subwordPrefix = prefix; // assign the argument; the original assigned the field to itself, making this setter a no-op
+ return this;
+ }
+
+ public String getSubwordPrefix() { return subwordPrefix; }
+
public void addModel(Language language, Path model) {
models.put(language, model);
}
@@ -113,16 +129,16 @@ public class BertEmbedder implements Embedder, Segmenter {
* Adds the model that will be used if the language is unknown, OR only one model is specified.
* The same as addModel(Language.UNKNOWN, model).
*/
- public BertEmbedder.Builder addDefaultModel(Path model) {
+ public WordPieceEmbedder.Builder addDefaultModel(Path model) {
addModel(Language.UNKNOWN, model);
return this;
}
public Map<Language, Path> getModels() { return models; }
- public BertEmbedder build() {
+ public WordPieceEmbedder build() {
if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied");
- return new BertEmbedder(this);
+ return new WordPieceEmbedder(this);
}
}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java
index e3f612f4114..0bbb6f001f5 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/bert/package-info.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/package-info.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
@ExportPackage
@PublicApi
-package com.yahoo.language.bert;
+package com.yahoo.language.wordpiece;
import com.yahoo.api.annotations.PublicApi;
import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/linguistics-components/src/main/resources/configdefinitions/language.bert.bert.def b/linguistics-components/src/main/resources/configdefinitions/language.wordpiece.word-piece.def
index 86d338758d0..08592250eb5 100644
--- a/linguistics-components/src/main/resources/configdefinitions/language.bert.bert.def
+++ b/linguistics-components/src/main/resources/configdefinitions/language.wordpiece.word-piece.def
@@ -1,8 +1,11 @@
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-# Configures com.yahoo.language.bert.BertEmbedder
+# Configures com.yahoo.language.wordpiece.WordPieceEmbedder
-namespace=language.bert
+namespace=language.wordpiece
+
+# The prefix to prepend to subword tokens
+subwordPrefix string default="##"
# The language a model is for, one of the language tags in com.yahoo.language.Language.
# Use "unknown" for a model to be used for any language (i.e by default).
diff --git a/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java b/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java
deleted file mode 100644
index 1bc25e0d217..00000000000
--- a/linguistics-components/src/test/java/com/yahoo/language/bert/BertEmbedderTest.java
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.language.bert;
-
-import com.yahoo.config.FileReference;
-import com.yahoo.language.Language;
-import com.yahoo.language.process.Embedder;
-import com.yahoo.language.simple.SimpleLinguistics;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-import org.junit.Test;
-
-import java.io.File;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * Tests the BERT embedder
- *
- * @author bratseth
- */
-public class BertEmbedderTest {
-
- private static final String vocabulary = "src/test/models/bert/bert-base-uncased-vocab.txt";
-
- @Test
- public void testBertEmbedder() {
- var embedder = new BertEmbedder.Builder().addDefaultModel(new File(vocabulary).toPath()).build();
- var expectedTokenIds = List.of(2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622);
- assertEquals(expectedTokenIds, embedder.embed("what was the impact of the manhattan project",
- new Embedder.Context("destination")));
-
- var expectedTokens = List.of("what", "was", "the", "impact", "of", "the", "manhattan", "project");
- assertEquals(expectedTokens, embedder.segment("what was the impact of the manhattan project",
- Language.ENGLISH));
-
- var expectedDenseTensor = Tensor.from("tensor(x[8]):" + expectedTokenIds);
- assertEquals(expectedDenseTensor, embedder.embed("what was the impact of the manhattan project",
- new Embedder.Context("destination"),
- expectedDenseTensor.type()));
- }
-
- @Test
- public void testBertEmbedderConfiguration() {
- var config = new BertConfig.Builder().model(new BertConfig.Model.Builder().language("unknown")
- .path(new FileReference(vocabulary)))
- .build();
- var embedder = new BertEmbedder(config);
- var expectedTokenIds = List.of(2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622);
- assertEquals(expectedTokenIds, embedder.embed("what was the impact of the manhattan project",
- new Embedder.Context("destination")));
- }
-
-}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
index 1ed2271f774..19cb2267655 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
@@ -4,6 +4,7 @@ package com.yahoo.language.sentencepiece;
import com.yahoo.config.FileReference;
import com.yahoo.language.Language;
+import com.yahoo.language.tools.EmbedderTester;
import org.junit.Test;
/**
@@ -15,7 +16,7 @@ public class SentencePieceConfigurationTest {
public void testEnglishTokenization() {
var b = new SentencePieceConfig.Builder();
addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
- var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
+ var tester = new EmbedderTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence");
tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo");
}
@@ -25,7 +26,7 @@ public class SentencePieceConfigurationTest {
var b = new SentencePieceConfig.Builder();
addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
b.collapseUnknowns(false);
- var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
+ var tester = new EmbedderTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
}
@@ -34,7 +35,7 @@ public class SentencePieceConfigurationTest {
var b = new SentencePieceConfig.Builder();
addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
b.scoring(SentencePieceConfig.Scoring.highestScore);
- var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
+ var tester = new EmbedderTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented("hello", "▁h", "el", "lo");
}
@@ -43,7 +44,7 @@ public class SentencePieceConfigurationTest {
var b = new SentencePieceConfig.Builder();
addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b);
addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
- var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
+ var tester = new EmbedderTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index 8b3e2988c43..2fbafb23485 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -3,6 +3,7 @@
package com.yahoo.language.sentencepiece;
import com.yahoo.language.Language;
+import com.yahoo.language.tools.EmbedderTester;
import org.junit.Test;
import java.io.File;
@@ -13,8 +14,8 @@ import java.io.File;
public class SentencePieceTest {
@Test
- public void testEnglishTokenization() {
- var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ public void testEnglishSegmenting() {
+ var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build());
tester.assertSegmented("h", "▁h");
tester.assertSegmented("he", "▁he");
tester.assertSegmented("hel", "▁hel");
@@ -36,33 +37,28 @@ public class SentencePieceTest {
}
@Test
- public void testIntegerListEncoding() {
- var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
- tester.assertEmbedded("hello, world!", 908, 1418, 9934, 501, 9960);
- tester.assertEmbedded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960);
- }
-
- @Test
- public void testDenseTensorEncoding() {
- var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
- tester.assertEmbedded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]");
- tester.assertEmbedded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]");
- tester.assertEmbedded("hello, world!", "tensor(d[2])", "[908,1418]");
+ public void testEnglishEmbedding() {
+ var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build());
+ tester.assertEmbedded("hello, world!", "tensor(d[10])", 908, 1418, 9934, 501, 9960);
+ tester.assertEmbedded("Hello, world!", "tensor(d[10])", 9912, 0, 6595, 9934, 501, 9960);
+ tester.assertEmbedded("hello, world!", "tensor(d[2])", 908, 1418, 9934, 501, 9960);
}
@Test
public void testNoCollapse() {
- var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder()
- .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
- .setCollapseUnknowns(false));
+ var builder = new SentencePieceEmbedder.Builder()
+ .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
+ .setCollapseUnknowns(false);
+ var tester = new EmbedderTester(builder.build());
tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
}
@Test
public void testHighestScore() {
- var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder()
- .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
- .setScoring(Scoring.highestScore));
+ var builder = new SentencePieceEmbedder.Builder()
+ .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
+ .setScoring(Scoring.highestScore);
+ var tester = new EmbedderTester(builder.build());
tester.assertSegmented("h", "▁h");
tester.assertSegmented("he", "▁he");
tester.assertSegmented("hel", "▁h", "el");
@@ -74,7 +70,7 @@ public class SentencePieceTest {
SentencePieceEmbedder.Builder builder = new SentencePieceEmbedder.Builder();
builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath());
builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
- var tester = new SentencePieceTester(builder);
+ var tester = new EmbedderTester(builder.build());
tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
deleted file mode 100644
index 4dae53c60df..00000000000
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-//
-
-package com.yahoo.language.sentencepiece;
-
-import com.yahoo.language.Language;
-import com.yahoo.language.process.Embedder;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.nio.file.Path;
-
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-
-class SentencePieceTester {
-
- private final SentencePieceEmbedder embedder;
-
- public SentencePieceTester(Path model) {
- this(new SentencePieceEmbedder.Builder().addDefaultModel(model));
- }
-
- public SentencePieceTester(SentencePieceEmbedder.Builder builder) {
- this(builder.build());
- }
-
- public SentencePieceTester(SentencePieceEmbedder embedder) {
- this.embedder = embedder;
- }
-
- public void assertEmbedded(String input, Integer... expectedCodes) {
- assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray());
- }
-
- public void assertEmbedded(String input, String tensorType, String tensor) {
- TensorType type = TensorType.fromSpec(tensorType);
- Tensor expected = Tensor.from(type, tensor);
- assertEquals(expected, embedder.embed(input, new Embedder.Context("test"), type));
- }
-
- public void assertSegmented(String input, String... expectedSegments) {
- assertSegmented(Language.UNKNOWN, input, expectedSegments);
- }
-
- public void assertSegmented(Language language, String input, String... expectedSegments) {
- assertArrayEquals(expectedSegments, embedder.segment(input, language).toArray());
- }
-
-}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
new file mode 100644
index 00000000000..9599e60e556
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
@@ -0,0 +1,59 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.tools;
+
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.language.process.Segmenter;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tester of embedders.
+ *
+ * @author bratseth
+ */
+public class EmbedderTester {
+
+ private final Embedder embedder;
+
+ public EmbedderTester(Embedder embedder) {
+ this.embedder = embedder;
+ }
+
+ /**
+ * Tests both embedding to a list of ids and encoding the same ids to a vector of the given type.
+ *
+ * @param expectedCodes all the expected codes of the given input, excluding the trailing zeros
+ * which are only added as padding when filling the tensor
+ */
+ public void assertEmbedded(String input, String tensorType, Integer... expectedCodes) {
+ TensorType type = TensorType.fromSpec(tensorType);
+ assertEquals(1, type.dimensions().size());
+ assertTrue(type.dimensions().get(0).isIndexed());
+
+ int tensorSize = type.dimensions().get(0).size().get().intValue();
+
+ assertArrayEquals(expectedCodes, embedder.embed(input, new Embedder.Context("test")).toArray());
+
+ var builder = Tensor.Builder.of(type);
+ for (int i = 0; i < tensorSize; i++)
+ builder.cell(i < expectedCodes.length ? expectedCodes[i] : 0, i);
+ assertEquals(builder.build(), embedder.embed(input, new Embedder.Context("destination"), type));
+ }
+
+ public void assertSegmented(String input, String... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+
+ public void assertSegmented(Language language, String input, String... expectedSegments) {
+ assertArrayEquals(expectedSegments, ((Segmenter)embedder).segment(input, language).toArray());
+ }
+
+}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java
new file mode 100644
index 00000000000..4cbfe541327
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/wordpiece/WordPieceEmbedderTest.java
@@ -0,0 +1,38 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.wordpiece;
+
+import com.yahoo.config.FileReference;
+import com.yahoo.language.tools.EmbedderTester;
+import org.junit.Test;
+
+/**
+ * Tests the WordPiece embedder
+ *
+ * @author bratseth
+ */
+public class WordPieceEmbedderTest {
+
+ private static final String vocabulary = "src/test/models/wordpiece/bert-base-uncased-vocab.txt";
+
+ @Test
+ public void testWordPieceEmbedder() {
+ var tester = new EmbedderTester(new WordPieceEmbedder.Builder(vocabulary).build());
+ tester.assertEmbedded("what was the impact of the manhattan project",
+ "tensor(x[8])",
+ 2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622);
+ tester.assertSegmented("what was the impact of the manhattan project",
+ "what", "was", "the", "impact", "of", "the", "manhattan", "project");
+ }
+
+ @Test
+ public void testWordPieceEmbedderConfiguration() {
+ var config = new WordPieceConfig.Builder().model(new WordPieceConfig.Model.Builder()
+ .language("unknown")
+ .path(new FileReference(vocabulary)))
+ .build();
+ var tester = new EmbedderTester(new WordPieceEmbedder(config));
+ tester.assertSegmented("what was the impact of the manhattan project",
+ "what", "was", "the", "impact", "of", "the", "manhattan", "project");
+ }
+
+}
diff --git a/linguistics-components/src/test/models/bert/bert-base-uncased-vocab.txt b/linguistics-components/src/test/models/wordpiece/bert-base-uncased-vocab.txt
index fb140275c15..fb140275c15 100644
--- a/linguistics-components/src/test/models/bert/bert-base-uncased-vocab.txt
+++ b/linguistics-components/src/test/models/wordpiece/bert-base-uncased-vocab.txt