summaryrefslogtreecommitdiffstats
path: root/linguistics/src/main
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@gmail.com> 2021-09-16 11:04:56 +0200
committer: Jon Bratseth <bratseth@gmail.com> 2021-09-16 11:04:56 +0200
commit: a2afdafbffcbc09594fd629c65746ec253f180be (patch)
tree: 987804d4362c801c6c47cfe469ec6b97409de6ef /linguistics/src/main
parent: 381033510b992049d55cae9964d942b4b47eb9df (diff)
Make SentencePieceEncoder configurable
Diffstat (limited to 'linguistics/src/main')
-rw-r--r-- linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java | 36
-rw-r--r-- linguistics/src/main/resources/configdefinitions/sentence-piece.def | 18
2 files changed, 48 insertions, 6 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
index a755a9e6ff3..4bf808bec0c 100644
--- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -3,6 +3,7 @@
package com.yahoo.language.sentencepiece;
import com.google.common.annotations.Beta;
+import com.google.inject.Inject;
import com.yahoo.io.IOUtils;
import com.yahoo.language.Language;
import com.yahoo.language.process.Segmenter;
@@ -30,8 +31,7 @@ import java.util.stream.Collectors;
public class SentencePieceEncoder implements Segmenter {
// TODO: Support characters beyond BMP
-
- public enum TokenType { text, control, userDefined, unknown, unused }
+ enum TokenType { text, control, userDefined, unknown, unused }
/** The scoring strategy to use for picking segments */
public enum Scoring {
@@ -48,6 +48,11 @@ public class SentencePieceEncoder implements Segmenter {
private final Map<Language, Model> models;
+ @Inject
+ public SentencePieceEncoder(SentencePieceConfig config) {
+ this(new Builder(config));
+ }
+
public SentencePieceEncoder(Builder builder) {
collapseUnknowns = builder.getCollapseUnknowns();
scoring = builder.getScoring();
@@ -56,6 +61,8 @@ public class SentencePieceEncoder implements Segmenter {
.stream()
.map(e -> new Model(e.getKey(), e.getValue()))
.collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m));
+ if (models.isEmpty())
+ throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured");
}
/**
@@ -250,6 +257,7 @@ public class SentencePieceEncoder implements Segmenter {
private static final class Model {
+ final Path source;
final Language language;
final float minScore;
final float maxScore;
@@ -257,6 +265,7 @@ public class SentencePieceEncoder implements Segmenter {
Model(Language language, Path path) {
try {
+ this.source = path;
this.language = language;
var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile()));
float minScore = Float.MAX_VALUE;
@@ -271,10 +280,15 @@ public class SentencePieceEncoder implements Segmenter {
this.maxScore = maxScore;
}
catch (IOException e) {
- throw new IllegalArgumentException("Could not read a SentencePiece model from '" + path + "'", e);
+ throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e);
}
}
+ @Override
+ public String toString() {
+ return "SentencePiece model for " + language + ": '" + source + "'";
+ }
+
}
public static class Builder {
@@ -283,6 +297,18 @@ public class SentencePieceEncoder implements Segmenter {
private boolean collapseUnknowns = true;
private Scoring scoring = Scoring.fewestSegments;
+ public Builder() {
+ }
+
+ private Builder(SentencePieceConfig config) {
+ collapseUnknowns = config.collapseUnknowns();
+ scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments
+ : Scoring.highestScore;
+ for (SentencePieceConfig.Model model : config.model()) {
+ addModel(Language.fromLanguageTag(model.language()), model.path());
+ }
+ }
+
public void addModel(Language language, Path model) {
models.put(language, model);
}
@@ -307,9 +333,7 @@ public class SentencePieceEncoder implements Segmenter {
}
public boolean getCollapseUnknowns() { return collapseUnknowns; }
- /**
- * Sets the scoring strategy to use when picking a segmentation. Default: fewestTokens.
- */
+ /** Sets the scoring strategy to use when picking a segmentation. Default: fewestSegments. */
public Builder setScoring(Scoring scoring) {
this.scoring = scoring;
return this;
diff --git a/linguistics/src/main/resources/configdefinitions/sentence-piece.def b/linguistics/src/main/resources/configdefinitions/sentence-piece.def
new file mode 100644
index 00000000000..b91c0c45dc4
--- /dev/null
+++ b/linguistics/src/main/resources/configdefinitions/sentence-piece.def
@@ -0,0 +1,18 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+# Configures com.yahoo.language.sentencepiece.SentencePieceEncoder
+
+namespace=language.sentencepiece
+
+# Whether consecutive unknown characters should be collapsed into one large unknown token (default)
+# or be returned as single character tokens.
+collapseUnknowns bool default=true
+
+# The scoring strategy to use when picking a segmentation.
+scoring enum { highestScore, fewestSegments } default=fewestSegments
+
+# The language a model is for, one of the language tags in com.yahoo.language.Language.
+# Use "unknown" for models to be used with any language.
+model[].language string
+# The path to the model relative to the application package root
+model[].path path \ No newline at end of file