From a2afdafbffcbc09594fd629c65746ec253f180be Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 16 Sep 2021 11:04:56 +0200 Subject: Make SentencePieceEncoder configurable --- linguistics/abi-spec.json | 153 ++++++++++++++++++--- .../sentencepiece/SentencePieceEncoder.java | 36 ++++- .../resources/configdefinitions/sentence-piece.def | 18 +++ .../SentencePieceConfigurationTest.java | 59 ++++++++ .../language/sentencepiece/SentencePieceTest.java | 33 +---- .../sentencepiece/SentencePieceTester.java | 40 ++++++ 6 files changed, 283 insertions(+), 56 deletions(-) create mode 100644 linguistics/src/main/resources/configdefinitions/sentence-piece.def create mode 100644 linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java create mode 100644 linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java (limited to 'linguistics') diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index f410e83645e..8df0848870e 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -701,6 +701,138 @@ ], "fields": [] }, + "com.yahoo.language.sentencepiece.SentencePieceConfig$Builder": { + "superClass": "java.lang.Object", + "interfaces": [ + "com.yahoo.config.ConfigInstance$Builder" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void ()", + "public void (com.yahoo.language.sentencepiece.SentencePieceConfig)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder collapseUnknowns(boolean)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder scoring(com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Builder model(java.util.List)", + "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", + "public final java.lang.String getDefMd5()", + "public final java.lang.String getDefName()", + "public final java.lang.String getDefNamespace()", + "public final boolean getApplyOnRestart()", + "public final void setApplyOnRestart(boolean)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig build()" + ], + "fields": [ + "public java.util.List model" + ] + }, + "com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder": { + "superClass": "java.lang.Object", + "interfaces": [ + "com.yahoo.config.ConfigBuilder" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void ()", + "public void (com.yahoo.language.sentencepiece.SentencePieceConfig$Model)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder language(java.lang.String)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder path(com.yahoo.config.FileReference)", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model build()" + ], + "fields": [] + }, + "com.yahoo.language.sentencepiece.SentencePieceConfig$Model": { + "superClass": "com.yahoo.config.InnerNode", + "interfaces": [], + "attributes": [ + "public", + "final" + ], + "methods": [ + "public void (com.yahoo.language.sentencepiece.SentencePieceConfig$Model$Builder)", + "public java.lang.String language()", + "public java.nio.file.Path path()" + ], + "fields": [] + }, + "com.yahoo.language.sentencepiece.SentencePieceConfig$Producer": { + "superClass": "java.lang.Object", + "interfaces": [ + "com.yahoo.config.ConfigInstance$Producer" + ], + "attributes": [ + "public", + "interface", + "abstract" + ], + "methods": [ + "public abstract void getConfig(com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)" + ], + "fields": [] + }, + "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum": { + "superClass": "java.lang.Enum", + "interfaces": [], + "attributes": [ + "public", + "final", + "enum" + ], + "methods": [ + "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum[] values()", + "public static com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum valueOf(java.lang.String)" + ], + "fields": [ + "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore", + "public static final enum com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments" + ] + }, + "com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring": { + "superClass": "com.yahoo.config.EnumNode", + "interfaces": [], + "attributes": [ + "public", + "final" + ], + "methods": [ + "public void ()", + "public void (com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum)" + ], + "fields": [ + "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum highestScore", + "public static final com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum fewestSegments" + ] + }, + "com.yahoo.language.sentencepiece.SentencePieceConfig": { + "superClass": "com.yahoo.config.ConfigInstance", + "interfaces": [], + "attributes": [ + "public", + "final" + ], + "methods": [ + "public static java.lang.String getDefMd5()", + "public static java.lang.String getDefName()", + "public static java.lang.String getDefNamespace()", + "public static java.lang.String getDefVersion()", + "public void (com.yahoo.language.sentencepiece.SentencePieceConfig$Builder)", + "public boolean collapseUnknowns()", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Scoring$Enum scoring()", + "public java.util.List model()", + "public com.yahoo.language.sentencepiece.SentencePieceConfig$Model model(int)" + ], + "fields": [ + "public static final java.lang.String CONFIG_DEF_MD5", + "public static final java.lang.String CONFIG_DEF_NAME", + "public static final java.lang.String CONFIG_DEF_NAMESPACE", + "public static final java.lang.String CONFIG_DEF_VERSION", + "public static final java.lang.String[] CONFIG_DEF_SCHEMA" + ] + }, "com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder": { "superClass": "java.lang.Object", "interfaces": [], @@ -737,26 +869,6 @@ "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring fewestSegments" ] }, - "com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType": { - "superClass": "java.lang.Enum", - "interfaces": [], - "attributes": [ - "public", - "final", - "enum" - ], - "methods": [ - "public static com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType[] values()", - "public static com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType valueOf(java.lang.String)" - ], - "fields": [ - "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType text", - "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType control", - "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType userDefined", - "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType unknown", - "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$TokenType unused" - ] - }, "com.yahoo.language.sentencepiece.SentencePieceEncoder": { "superClass": "java.lang.Object", "interfaces": [ @@ -766,6 +878,7 @@ "public" ], "methods": [ + "public void (com.yahoo.language.sentencepiece.SentencePieceConfig)", "public void (com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder)", "public java.util.List segment(java.lang.String, com.yahoo.language.Language)", "public java.util.List encode(java.lang.String, com.yahoo.language.Language)", diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java index a755a9e6ff3..4bf808bec0c 100644 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -3,6 +3,7 @@ package com.yahoo.language.sentencepiece; import com.google.common.annotations.Beta; +import com.google.inject.Inject; import com.yahoo.io.IOUtils; import com.yahoo.language.Language; import com.yahoo.language.process.Segmenter; @@ -30,8 +31,7 @@ import java.util.stream.Collectors; public class SentencePieceEncoder implements Segmenter { // TODO: Support characters beyond BMP - - public enum TokenType { text, control, userDefined, unknown, unused } + enum TokenType { text, control, userDefined, unknown, unused } /** The scoring strategy to use for picking segments */ public enum Scoring { @@ -48,6 +48,11 @@ public class SentencePieceEncoder implements Segmenter { private final Map models; + @Inject + public SentencePieceEncoder(SentencePieceConfig config) { + this(new Builder(config)); + } + public SentencePieceEncoder(Builder builder) { collapseUnknowns = builder.getCollapseUnknowns(); scoring = builder.getScoring(); @@ -56,6 +61,8 @@ public class SentencePieceEncoder implements Segmenter { .stream() .map(e -> new Model(e.getKey(), e.getValue())) .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); + if (models.isEmpty()) + throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured"); } /** @@ -250,6 +257,7 @@ public class SentencePieceEncoder implements Segmenter { private static final class Model { + final Path source; final Language language; final float minScore; final float maxScore; @@ -257,6 +265,7 @@ public class SentencePieceEncoder implements Segmenter { Model(Language language, Path path) { try { + this.source = path; this.language = language; var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile())); float minScore = Float.MAX_VALUE; @@ -271,10 +280,15 @@ public class SentencePieceEncoder implements Segmenter { this.maxScore = maxScore; } catch (IOException e) { - throw new IllegalArgumentException("Could not read a SentencePiece model from '" + path + "'", e); + throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e); } } + @Override + public String toString() { + return "SentencePiece model for " + language + ": '" + source + "'"; + } + } public static class Builder { @@ -283,6 +297,18 @@ public class SentencePieceEncoder implements Segmenter { private boolean collapseUnknowns = true; private Scoring scoring = Scoring.fewestSegments; + public Builder() { + } + + private Builder(SentencePieceConfig config) { + collapseUnknowns = config.collapseUnknowns(); + scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments + : Scoring.highestScore; + for (SentencePieceConfig.Model model : config.model()) { + addModel(Language.fromLanguageTag(model.language()), model.path()); + } + } + public void addModel(Language language, Path model) { models.put(language, model); } @@ -307,9 +333,7 @@ public class SentencePieceEncoder implements Segmenter { } public boolean getCollapseUnknowns() { return collapseUnknowns; } - /** - * Sets the scoring strategy to use when picking a segmentation. Default: fewestTokens. - */ + /** Sets the scoring strategy to use when picking a segmentation. Default: fewestSegments. */ public Builder setScoring(Scoring scoring) { this.scoring = scoring; return this; diff --git a/linguistics/src/main/resources/configdefinitions/sentence-piece.def b/linguistics/src/main/resources/configdefinitions/sentence-piece.def new file mode 100644 index 00000000000..b91c0c45dc4 --- /dev/null +++ b/linguistics/src/main/resources/configdefinitions/sentence-piece.def @@ -0,0 +1,18 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +# Configures com.yahoo.language.sentencepiece.SentencePieceEncoder + +namespace=language.sentencepiece + +# Whether consecutive unknown character should be collapsed into one large unknown token (default +# or be returned as single character tokens. +collapseUnknowns bool default=true + +# The scoring strategy to use when picking a segmentation. +scoring enum { highestScore, fewestSegments } default=fewestSegments + +# The language a model is for, one of the language tags in com.yahoo.language.Language. +# Use "unknown" for models to be used with any language. +model[].language string +# The path to the model relative to the application package root +model[].path path \ No newline at end of file diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java new file mode 100644 index 00000000000..edbbe21ec53 --- /dev/null +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java @@ -0,0 +1,59 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.sentencepiece; + +import com.yahoo.config.FileReference; +import com.yahoo.language.Language; +import org.junit.Test; + +/** + * @author bratseth + */ +public class SentencePieceConfigurationTest { + + @Test + public void testEnglishTokenization() { + var b = new SentencePieceConfig.Builder(); + addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); + tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); + } + + @Test + public void testNoCollapse() { + var b = new SentencePieceConfig.Builder(); + addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + b.collapseUnknowns(false); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); + } + + @Test + public void testHighestScore() { + var b = new SentencePieceConfig.Builder(); + addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + b.scoring(SentencePieceConfig.Scoring.highestScore); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented("hello", "▁h", "el", "lo"); + } + + @Test + public void testMultiLanguageTokenization() { + var b = new SentencePieceConfig.Builder(); + addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b); + addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); + tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); + tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); + } + + private void addModel(String language, String file, SentencePieceConfig.Builder b) { + var mb = new SentencePieceConfig.Model.Builder(); + mb.language(language); + mb.path(new FileReference(file)); + b.model(mb); + } + +} diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index 7d0c1c5c78e..f86bc2f716b 100644 --- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -6,10 +6,6 @@ import com.yahoo.language.Language; import org.junit.Test; import java.io.File; -import java.io.IOException; -import java.nio.file.Path; - -import static org.junit.Assert.assertArrayEquals; /** * @author bratseth @@ -61,37 +57,14 @@ public class SentencePieceTest { } @Test - public void testJapaneseTokenization() throws IOException { + public void testMultiLanguageTokenization() { SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder(); builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); var tester = new SentencePieceTester(builder); tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); - } - - private static class SentencePieceTester { - - private final SentencePieceEncoder encoder; - - public SentencePieceTester(Path model) { - this(new SentencePieceEncoder.Builder().addDefaultModel(model)); - } - - public SentencePieceTester(SentencePieceEncoder.Builder builder) { - encoder = builder.build(); - } - - private void assertEncoded(String input, Integer ... expectedCodes) { - assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray()); - } - - private void assertSegmented(String input, String ... expectedSegments) { - assertSegmented(Language.UNKNOWN, input, expectedSegments); - } - private void assertSegmented(Language language, String input, String ... expectedSegments) { - assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray()); - } - + tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); + tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); } } diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java new file mode 100644 index 00000000000..dee9be5aa7e --- /dev/null +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java @@ -0,0 +1,40 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// + +package com.yahoo.language.sentencepiece; + +import com.yahoo.language.Language; + +import java.nio.file.Path; + +import static org.junit.Assert.assertArrayEquals; + +class SentencePieceTester { + + private final SentencePieceEncoder encoder; + + public SentencePieceTester(Path model) { + this(new SentencePieceEncoder.Builder().addDefaultModel(model)); + } + + public SentencePieceTester(SentencePieceEncoder.Builder builder) { + this(builder.build()); + } + + public SentencePieceTester(SentencePieceEncoder encoder) { + this.encoder = encoder; + } + + public void assertEncoded(String input, Integer... expectedCodes) { + assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray()); + } + + public void assertSegmented(String input, String... expectedSegments) { + assertSegmented(Language.UNKNOWN, input, expectedSegments); + } + + public void assertSegmented(Language language, String input, String... expectedSegments) { + assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray()); + } + +} -- cgit v1.2.3