diff options
Diffstat (limited to 'linguistics/src/main')
-rw-r--r-- | linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java | 36 | ||||
-rw-r--r-- | linguistics/src/main/resources/configdefinitions/sentence-piece.def | 18 |
2 files changed, 48 insertions, 6 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java index a755a9e6ff3..4bf808bec0c 100644 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -3,6 +3,7 @@ package com.yahoo.language.sentencepiece; import com.google.common.annotations.Beta; +import com.google.inject.Inject; import com.yahoo.io.IOUtils; import com.yahoo.language.Language; import com.yahoo.language.process.Segmenter; @@ -30,8 +31,7 @@ import java.util.stream.Collectors; public class SentencePieceEncoder implements Segmenter { // TODO: Support characters beyond BMP - - public enum TokenType { text, control, userDefined, unknown, unused } + enum TokenType { text, control, userDefined, unknown, unused } /** The scoring strategy to use for picking segments */ public enum Scoring { @@ -48,6 +48,11 @@ public class SentencePieceEncoder implements Segmenter { private final Map<Language, Model> models; + @Inject + public SentencePieceEncoder(SentencePieceConfig config) { + this(new Builder(config)); + } + public SentencePieceEncoder(Builder builder) { collapseUnknowns = builder.getCollapseUnknowns(); scoring = builder.getScoring(); @@ -56,6 +61,8 @@ public class SentencePieceEncoder implements Segmenter { .stream() .map(e -> new Model(e.getKey(), e.getValue())) .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); + if (models.isEmpty()) + throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured"); } /** @@ -250,6 +257,7 @@ public class SentencePieceEncoder implements Segmenter { private static final class Model { + final Path source; final Language language; final float minScore; final float maxScore; @@ -257,6 +265,7 @@ public class SentencePieceEncoder implements Segmenter { Model(Language 
language, Path path) { try { + this.source = path; this.language = language; var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile())); float minScore = Float.MAX_VALUE; @@ -271,10 +280,15 @@ public class SentencePieceEncoder implements Segmenter { this.maxScore = maxScore; } catch (IOException e) { - throw new IllegalArgumentException("Could not read a SentencePiece model from '" + path + "'", e); + throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e); } } + @Override + public String toString() { + return "SentencePiece model for " + language + ": '" + source + "'"; + } + } public static class Builder { @@ -283,6 +297,18 @@ public class SentencePieceEncoder implements Segmenter { private boolean collapseUnknowns = true; private Scoring scoring = Scoring.fewestSegments; + public Builder() { + } + + private Builder(SentencePieceConfig config) { + collapseUnknowns = config.collapseUnknowns(); + scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments + : Scoring.highestScore; + for (SentencePieceConfig.Model model : config.model()) { + addModel(Language.fromLanguageTag(model.language()), model.path()); + } + } + public void addModel(Language language, Path model) { models.put(language, model); } @@ -307,9 +333,7 @@ public class SentencePieceEncoder implements Segmenter { } public boolean getCollapseUnknowns() { return collapseUnknowns; } - /** - * Sets the scoring strategy to use when picking a segmentation. Default: fewestTokens. - */ + /** Sets the scoring strategy to use when picking a segmentation. Default: fewestSegments. 
*/ public Builder setScoring(Scoring scoring) { this.scoring = scoring; return this; diff --git a/linguistics/src/main/resources/configdefinitions/sentence-piece.def b/linguistics/src/main/resources/configdefinitions/sentence-piece.def new file mode 100644 index 00000000000..b91c0c45dc4 --- /dev/null +++ b/linguistics/src/main/resources/configdefinitions/sentence-piece.def @@ -0,0 +1,18 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +# Configures com.yahoo.language.sentencepiece.SentencePieceEncoder + +namespace=language.sentencepiece + +# Whether consecutive unknown characters should be collapsed into one large unknown token (default) +# or be returned as single character tokens. +collapseUnknowns bool default=true + +# The scoring strategy to use when picking a segmentation. +scoring enum { highestScore, fewestSegments } default=fewestSegments + +# The language a model is for, one of the language tags in com.yahoo.language.Language. +# Use "unknown" for models to be used with any language. +model[].language string +# The path to the model relative to the application package root +model[].path path
\ No newline at end of file |