summary | refs | log | tree | commit | diff | stats
path: root/linguistics
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@gmail.com> 2021-09-16 11:04:56 +0200
committer: Jon Bratseth <bratseth@gmail.com> 2021-09-16 11:04:56 +0200
commit: a2afdafbffcbc09594fd629c65746ec253f180be (patch)
tree: 987804d4362c801c6c47cfe469ec6b97409de6ef /linguistics
parent: 381033510b992049d55cae9964d942b4b47eb9df (diff)
Make SentencePieceEncoder configurable
Diffstat (limited to 'linguistics')
-rw-r--r-- linguistics/abi-spec.json 153
-rw-r--r-- linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java 36
-rw-r--r-- linguistics/src/main/resources/configdefinitions/sentence-piece.def 18
-rw-r--r-- linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java 59
-rw-r--r-- linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java 33
-rw-r--r-- linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java 40
6 files changed, 283 insertions, 56 deletions
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index f410e83645e..8df0848870e 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -701,6 +701,138 @@
],
"fields": []
},
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Builder": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "com.yahoo.config.ConfigInstance$Builder"
+ ],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>()",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder collapseUnknowns(boolean)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder scoring(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.List)",
+ "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
+ "public final java.lang.String getDefMd5()",
+ "public final java.lang.String getDefName()",
+ "public final java.lang.String getDefNamespace()",
+ "public final boolean getApplyOnRestart()",
+ "public final void setApplyOnRestart(boolean)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig build()"
+ ],
+ "fields": [
+ "public java.util.List model"
+ ]
+ },
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "com.yahoo.config.ConfigBuilder"
+ ],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>()",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder language(java.lang.String)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder path(com.yahoo.config.FileReference)",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model build()"
+ ],
+ "fields": []
+ },
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Model": {
+ "superClass": "com.yahoo.config.InnerNode",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)",
+ "public java.lang.String language()",
+ "public java.nio.file.Path path()"
+ ],
+ "fields": []
+ },
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Producer": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "com.yahoo.config.ConfigInstance$Producer"
+ ],
+ "attributes": [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods": [
+ "public abstract void getConfig(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)"
+ ],
+ "fields": []
+ },
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum": {
+ "superClass": "java.lang.Enum",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final",
+ "enum"
+ ],
+ "methods": [
+ "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum[] values()",
+ "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum valueOf(java.lang.String)"
+ ],
+ "fields": [
+ "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore",
+ "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments"
+ ]
+ },
+ "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring": {
+ "superClass": "com.yahoo.config.EnumNode",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final"
+ ],
+ "methods": [
+ "public void <init>()",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)"
+ ],
+ "fields": [
+ "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore",
+ "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments"
+ ]
+ },
+ "com.yahoo.language.sentencepiece.SentencePieceConfig": {
+ "superClass": "com.yahoo.config.ConfigInstance",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final"
+ ],
+ "methods": [
+ "public static java.lang.String getDefMd5()",
+ "public static java.lang.String getDefName()",
+ "public static java.lang.String getDefNamespace()",
+ "public static java.lang.String getDefVersion()",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)",
+ "public boolean collapseUnknowns()",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum scoring()",
+ "public java.util.List model()",
+ "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model model(int)"
+ ],
+ "fields": [
+ "public static final java.lang.String CONFIG_DEF_MD5",
+ "public static final java.lang.String CONFIG_DEF_NAME",
+ "public static final java.lang.String CONFIG_DEF_NAMESPACE",
+ "public static final java.lang.String CONFIG_DEF_VERSION",
+ "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
+ ]
+ },
"com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder": {
"superClass": "java.lang.Object",
"interfaces": [],
@@ -737,26 +869,6 @@
"public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring fewestSegments"
]
},
- "com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType": {
- "superClass": "java.lang.Enum",
- "interfaces": [],
- "attributes": [
- "public",
- "final",
- "enum"
- ],
- "methods": [
- "public static com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType[] values()",
- "public static com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType valueOf(java.lang.String)"
- ],
- "fields": [
- "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType text",
- "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType control",
- "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType userDefined",
- "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType unknown",
- "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType unused"
- ]
- },
"com.yahoo.language.sentencepiece.SentencePieceEncoder": {
"superClass": "java.lang.Object",
"interfaces": [
@@ -766,6 +878,7 @@
"public"
],
"methods": [
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)",
"public void <init>(com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder)",
"public java.util.List segment(java.lang.String, com.yahoo.language.Language)",
"public java.util.List encode(java.lang.String, com.yahoo.language.Language)",
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
index a755a9e6ff3..4bf808bec0c 100644
--- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -3,6 +3,7 @@
package com.yahoo.language.sentencepiece;
import com.google.common.annotations.Beta;
+import com.google.inject.Inject;
import com.yahoo.io.IOUtils;
import com.yahoo.language.Language;
import com.yahoo.language.process.Segmenter;
@@ -30,8 +31,7 @@ import java.util.stream.Collectors;
public class SentencePieceEncoder implements Segmenter {
// TODO: Support characters beyond BMP
-
- public enum TokenType { text, control, userDefined, unknown, unused }
+ enum TokenType { text, control, userDefined, unknown, unused }
/** The scoring strategy to use for picking segments */
public enum Scoring {
@@ -48,6 +48,11 @@ public class SentencePieceEncoder implements Segmenter {
private final Map<Language, Model> models;
+ @Inject
+ public SentencePieceEncoder(SentencePieceConfig config) {
+ this(new Builder(config));
+ }
+
public SentencePieceEncoder(Builder builder) {
collapseUnknowns = builder.getCollapseUnknowns();
scoring = builder.getScoring();
@@ -56,6 +61,8 @@ public class SentencePieceEncoder implements Segmenter {
.stream()
.map(e -> new Model(e.getKey(), e.getValue()))
.collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m));
+ if (models.isEmpty())
+ throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured");
}
/**
@@ -250,6 +257,7 @@ public class SentencePieceEncoder implements Segmenter {
private static final class Model {
+ final Path source;
final Language language;
final float minScore;
final float maxScore;
@@ -257,6 +265,7 @@ public class SentencePieceEncoder implements Segmenter {
Model(Language language, Path path) {
try {
+ this.source = path;
this.language = language;
var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile()));
float minScore = Float.MAX_VALUE;
@@ -271,10 +280,15 @@ public class SentencePieceEncoder implements Segmenter {
this.maxScore = maxScore;
}
catch (IOException e) {
- throw new IllegalArgumentException("Could not read a SentencePiece model from '" + path + "'", e);
+ throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e);
}
}
+ @Override
+ public String toString() {
+ return "SentencePiece model for " + language + ": '" + source + "'";
+ }
+
}
public static class Builder {
@@ -283,6 +297,18 @@ public class SentencePieceEncoder implements Segmenter {
private boolean collapseUnknowns = true;
private Scoring scoring = Scoring.fewestSegments;
+ public Builder() {
+ }
+
+ private Builder(SentencePieceConfig config) {
+ collapseUnknowns = config.collapseUnknowns();
+ scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments
+ : Scoring.highestScore;
+ for (SentencePieceConfig.Model model : config.model()) {
+ addModel(Language.fromLanguageTag(model.language()), model.path());
+ }
+ }
+
public void addModel(Language language, Path model) {
models.put(language, model);
}
@@ -307,9 +333,7 @@ public class SentencePieceEncoder implements Segmenter {
}
public boolean getCollapseUnknowns() { return collapseUnknowns; }
- /**
- * Sets the scoring strategy to use when picking a segmentation. Default: fewestTokens.
- */
+ /** Sets the scoring strategy to use when picking a segmentation. Default: fewestSegments. */
public Builder setScoring(Scoring scoring) {
this.scoring = scoring;
return this;
diff --git a/linguistics/src/main/resources/configdefinitions/sentence-piece.def b/linguistics/src/main/resources/configdefinitions/sentence-piece.def
new file mode 100644
index 00000000000..b91c0c45dc4
--- /dev/null
+++ b/linguistics/src/main/resources/configdefinitions/sentence-piece.def
@@ -0,0 +1,18 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+# Configures com.yahoo.language.sentencepiece.SentencePieceEncoder
+
+namespace=language.sentencepiece
+
+# Whether consecutive unknown characters should be collapsed into one large unknown token (default)
+# or be returned as single character tokens.
+collapseUnknowns bool default=true
+
+# The scoring strategy to use when picking a segmentation.
+scoring enum { highestScore, fewestSegments } default=fewestSegments
+
+# The language a model is for, one of the language tags in com.yahoo.language.Language.
+# Use "unknown" for models to be used with any language.
+model[].language string
+# The path to the model relative to the application package root
+model[].path path
\ No newline at end of file
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
new file mode 100644
index 00000000000..edbbe21ec53
--- /dev/null
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
@@ -0,0 +1,59 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.config.FileReference;
+import com.yahoo.language.Language;
+import org.junit.Test;
+
+/**
+ * @author bratseth
+ */
+public class SentencePieceConfigurationTest {
+
+ @Test
+ public void testEnglishTokenization() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence");
+ tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo");
+ }
+
+ @Test
+ public void testNoCollapse() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ b.collapseUnknowns(false);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
+ }
+
+ @Test
+ public void testHighestScore() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ b.scoring(SentencePieceConfig.Scoring.highestScore);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented("hello", "▁h", "el", "lo");
+ }
+
+ @Test
+ public void testMultiLanguageTokenization() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b);
+ addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
+ tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
+ tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
+ }
+
+ private void addModel(String language, String file, SentencePieceConfig.Builder b) {
+ var mb = new SentencePieceConfig.Model.Builder();
+ mb.language(language);
+ mb.path(new FileReference(file));
+ b.model(mb);
+ }
+
+}
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index 7d0c1c5c78e..f86bc2f716b 100644
--- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -6,10 +6,6 @@ import com.yahoo.language.Language;
import org.junit.Test;
import java.io.File;
-import java.io.IOException;
-import java.nio.file.Path;
-
-import static org.junit.Assert.assertArrayEquals;
/**
* @author bratseth
@@ -61,37 +57,14 @@ public class SentencePieceTest {
}
@Test
- public void testJapaneseTokenization() throws IOException {
+ public void testMultiLanguageTokenization() {
SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder();
builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath());
builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
var tester = new SentencePieceTester(builder);
tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
- }
-
- private static class SentencePieceTester {
-
- private final SentencePieceEncoder encoder;
-
- public SentencePieceTester(Path model) {
- this(new SentencePieceEncoder.Builder().addDefaultModel(model));
- }
-
- public SentencePieceTester(SentencePieceEncoder.Builder builder) {
- encoder = builder.build();
- }
-
- private void assertEncoded(String input, Integer ... expectedCodes) {
- assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
- }
-
- private void assertSegmented(String input, String ... expectedSegments) {
- assertSegmented(Language.UNKNOWN, input, expectedSegments);
- }
- private void assertSegmented(Language language, String input, String ... expectedSegments) {
- assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
- }
-
+ tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
+ tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
}
}
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
new file mode 100644
index 00000000000..dee9be5aa7e
--- /dev/null
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -0,0 +1,40 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+//
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.language.Language;
+
+import java.nio.file.Path;
+
+import static org.junit.Assert.assertArrayEquals;
+
+class SentencePieceTester {
+
+ private final SentencePieceEncoder encoder;
+
+ public SentencePieceTester(Path model) {
+ this(new SentencePieceEncoder.Builder().addDefaultModel(model));
+ }
+
+ public SentencePieceTester(SentencePieceEncoder.Builder builder) {
+ this(builder.build());
+ }
+
+ public SentencePieceTester(SentencePieceEncoder encoder) {
+ this.encoder = encoder;
+ }
+
+ public void assertEncoded(String input, Integer... expectedCodes) {
+ assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
+ }
+
+ public void assertSegmented(String input, String... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+
+ public void assertSegmented(Language language, String input, String... expectedSegments) {
+ assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
+ }
+
+}