summaryrefslogtreecommitdiffstats
path: root/linguistics
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2021-09-17 08:01:29 +0200
committerGitHub <noreply@github.com>2021-09-17 08:01:29 +0200
commit4d31a75b8a249593d0a3503669d3399b980c8be1 (patch)
tree050c7f725f4efdb736912c405008bd1d49bd782c /linguistics
parentdaab62042f34575d545dcd0b6fd100e232848c85 (diff)
parenta0f2ddb8b759a928329996050c818f5a4fae90b0 (diff)
Merge pull request #19180 from vespa-engine/bratseth/encoder-interface
Bratseth/encoder interface
Diffstat (limited to 'linguistics')
-rw-r--r--linguistics/abi-spec.json66
-rw-r--r--linguistics/src/main/java/com/yahoo/language/process/Encoder.java39
-rw-r--r--linguistics/src/main/java/com/yahoo/language/process/Stemmer.java4
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java60
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java47
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java17
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java90
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java227
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java13
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java36
-rw-r--r--linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java3
11 files changed, 374 insertions, 228 deletions
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 136d07721de..e8687b5c9f4 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -328,6 +328,20 @@
],
"fields": []
},
+ "com.yahoo.language.process.Encoder": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods": [
+ "public abstract java.util.List encode(java.lang.String, com.yahoo.language.Language)",
+ "public abstract com.yahoo.tensor.Tensor encode(java.lang.String, com.yahoo.language.Language, com.yahoo.tensor.TensorType)"
+ ],
+ "fields": []
+ },
"com.yahoo.language.process.GramSplitter$Gram": {
"superClass": "java.lang.Object",
"interfaces": [],
@@ -701,6 +715,23 @@
],
"fields": []
},
+ "com.yahoo.language.sentencepiece.Scoring": {
+ "superClass": "java.lang.Enum",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final",
+ "enum"
+ ],
+ "methods": [
+ "public static com.yahoo.language.sentencepiece.Scoring[] values()",
+ "public static com.yahoo.language.sentencepiece.Scoring valueOf(java.lang.String)"
+ ],
+ "fields": [
+ "public static final enum com.yahoo.language.sentencepiece.Scoring highestScore",
+ "public static final enum com.yahoo.language.sentencepiece.Scoring fewestSegments"
+ ]
+ },
"com.yahoo.language.sentencepiece.SentencePieceConfig$Builder": {
"superClass": "java.lang.Object",
"interfaces": [
@@ -846,33 +877,17 @@
"public java.util.Map getModels()",
"public com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder setCollapseUnknowns(boolean)",
"public boolean getCollapseUnknowns()",
- "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder setScoring(com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring)",
- "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring getScoring()",
+ "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder setScoring(com.yahoo.language.sentencepiece.Scoring)",
+ "public com.yahoo.language.sentencepiece.Scoring getScoring()",
"public com.yahoo.language.sentencepiece.SentencePieceEncoder build()"
],
"fields": []
},
- "com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring": {
- "superClass": "java.lang.Enum",
- "interfaces": [],
- "attributes": [
- "public",
- "final",
- "enum"
- ],
- "methods": [
- "public static com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring[] values()",
- "public static com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring valueOf(java.lang.String)"
- ],
- "fields": [
- "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring highestScore",
- "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring fewestSegments"
- ]
- },
"com.yahoo.language.sentencepiece.SentencePieceEncoder": {
"superClass": "java.lang.Object",
"interfaces": [
- "com.yahoo.language.process.Segmenter"
+ "com.yahoo.language.process.Segmenter",
+ "com.yahoo.language.process.Encoder"
],
"attributes": [
"public"
@@ -886,5 +901,16 @@
"public java.lang.String normalize(java.lang.String)"
],
"fields": []
+ },
+ "com.yahoo.language.sentencepiece.Trie": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>()"
+ ],
+ "fields": []
}
} \ No newline at end of file
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Encoder.java b/linguistics/src/main/java/com/yahoo/language/process/Encoder.java
new file mode 100644
index 00000000000..91de16f669b
--- /dev/null
+++ b/linguistics/src/main/java/com/yahoo/language/process/Encoder.java
@@ -0,0 +1,39 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.process;
+
+import com.yahoo.language.Language;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.List;
+
+/**
+ * An encoder converts a text string to a tensor or list of tokens
+ *
+ * @author bratseth
+ */
+public interface Encoder {
+
+ /**
+ * Encodes text into tokens in a list of ids.
+ *
+ * @param text the text to encode
+ * @param language the language of the text, or UNKNOWN to use language independent encoding
+ * @return the text encoded to a list of segment ids
+ * @throws IllegalArgumentException if the language is not supported by this encoder
+ */
+ List<Integer> encode(String text, Language language);
+
+ /**
+ * Encodes text into tokens in a tensor.
+ * The information contained in the encoding may depend on the tensor type.
+ *
+ * @param text the text to encode
+ * @param language the language of the text, or UNKNOWN to use language independent encoding
+ * @param tensorType the type of the ttensor to be returned
+ * @return the tex encoded into a tensor of the supplied type
+ * @throws IllegalArgumentException if the language or tensor type is not supported by this encoder
+ */
+ Tensor encode(String text, Language language, TensorType tensorType);
+
+}
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Stemmer.java b/linguistics/src/main/java/com/yahoo/language/process/Stemmer.java
index da8a73407ff..a2d0d0a84c9 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Stemmer.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Stemmer.java
@@ -6,9 +6,9 @@ import com.yahoo.language.Language;
import java.util.List;
/**
- * <p>Interface providing stemming of single words.</p>
+ * Interface providing stemming of single words.
*
- * @author <a href="mailto:mathiasm@yahoo-inc.com">Mathias Mølster Lidal</a>
+ * @author Mathias Mølster Lidal
*/
public interface Stemmer {
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java
new file mode 100644
index 00000000000..74f300057dc
--- /dev/null
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java
@@ -0,0 +1,60 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.io.IOUtils;
+import com.yahoo.language.Language;
+import sentencepiece.SentencepieceModel;
+
+import java.io.IOException;
+import java.nio.file.Path;
+
+/**
+ * A SentencePiece model
+ *
+ * @author bratseth
+ */
+final class Model {
+
+ final Path source;
+ final Language language;
+ final float minScore;
+ final float maxScore;
+ final Trie tokens = new Trie();
+
+ Model(Language language, Path path) {
+ try {
+ this.source = path;
+ this.language = language;
+ var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile()));
+ float minScore = Float.MAX_VALUE;
+ float maxScore = Float.MIN_VALUE;
+ for (int i = 0; i < sp.getPiecesCount(); i++) {
+ var piece = sp.getPieces(i);
+ tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore());
+ minScore = Math.min(piece.getScore(), minScore);
+ maxScore = Math.max(piece.getScore(), maxScore);
+ }
+ this.minScore = minScore;
+ this.maxScore = maxScore;
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e);
+ }
+ }
+
+ private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) {
+ switch (type) {
+ case USER_DEFINED : return TokenType.userDefined;
+ case UNKNOWN : return TokenType.unknown;
+ case NORMAL : return TokenType.text;
+ case CONTROL : return TokenType.control;
+ case UNUSED : return TokenType.unused;
+ default : throw new IllegalArgumentException("Unknkown token type " + type);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "SentencePiece model for " + language + ": '" + source + "'";
+ }
+
+}
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java
new file mode 100644
index 00000000000..2141505374c
--- /dev/null
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java
@@ -0,0 +1,47 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+/**
+ * Builds a result from a sentencepiece tokenization by being called for each segment in reverse
+ *
+ * @param <RESULTTYPE> the type of result this produces
+ * @author bratseth
+ */
+abstract class ResultBuilder<RESULTTYPE> {
+
+ private final RESULTTYPE result;
+
+ ResultBuilder(RESULTTYPE result) {
+ this.result = result;
+ }
+
+ /** Called for each segment, starting from the last and working backwards */
+ abstract void add(int start, int end, SentencePieceAlgorithm.SegmentEnd[] segmentEnds);
+
+ RESULTTYPE result() {return result;}
+
+ void build(String input, SentencePieceAlgorithm.SegmentEnd[] segmentEnds, boolean collapseUnknowns) {
+ if (collapseUnknowns) {
+ int segmentEnd = input.length();
+ int collapsedSegmentEnd = segmentEnd;
+ while (segmentEnd > 0) {
+ if (segmentEnds[segmentEnd].type != TokenType.unknown ) {
+ if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment
+ add(segmentEnd, collapsedSegmentEnd, segmentEnds);
+ }
+ add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds);
+ collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart;
+ }
+ segmentEnd = segmentEnds[segmentEnd].segmentStart;
+ }
+ }
+ else {
+ int segmentEnd = input.length();
+ while (segmentEnd > 0) {
+ add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds);
+ segmentEnd = segmentEnds[segmentEnd].segmentStart;
+ }
+ }
+ }
+
+}
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java
new file mode 100644
index 00000000000..6c8560abee7
--- /dev/null
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java
@@ -0,0 +1,17 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+/**
+ * The scoring strategy to use for picking segments
+ *
+ * @author bratseth
+ */
+public enum Scoring {
+
+ /** Find the segmentation that has the highest score */
+ highestScore,
+
+ /** Find the segmentation that has the fewest segments, resolve ties by score sum */
+ fewestSegments
+
+}
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java
new file mode 100644
index 00000000000..1659e3c0fa7
--- /dev/null
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java
@@ -0,0 +1,90 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+/**
+ * SentencePiece algorithm implementation
+ *
+ * @author bratseth
+ */
+class SentencePieceAlgorithm {
+
+ // TODO: Support characters beyond BMP
+
+ static final char spaceSymbol = '▁';
+
+ private final boolean collapseUnknowns;
+ private final Scoring scoring;
+
+ SentencePieceAlgorithm(boolean collapseUnknowns, Scoring scoring) {
+ this.collapseUnknowns = collapseUnknowns;
+ this.scoring = scoring;
+ }
+
+ public <RESULTTYPE> void segment(String input, ResultBuilder<RESULTTYPE> resultBuilder, Model model) {
+ SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1];
+ segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0);
+ int start = 0;
+ while (start < input.length()) { // segment from this position to the end of the text
+ Trie.Node node = model.tokens.root;
+ int characterPosition = start;
+ while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position
+ node = node.children.get(input.charAt(characterPosition++));
+ int length = characterPosition - start;
+ if (node != null && node.isToken() && node.type != TokenType.unused) {
+ float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score;
+ addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds);
+ }
+ else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable
+ addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds);
+ }
+ }
+ start++;
+ }
+ resultBuilder.build(input, segmentEnds, collapseUnknowns);
+ }
+
+ private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) {
+ if (segmentEnds[end] == null ||
+ segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) {
+ segmentEnds[end] = new SegmentEnd(type, id,
+ segmentEnds[start].pathScoreSum + score,
+ segmentEnds[start].pathSegmentCount + 1,
+ start);
+ }
+ }
+
+ final class SegmentEnd {
+
+ final TokenType type;
+ final int id;
+ final float pathScoreSum;
+ final int pathSegmentCount;
+ final int segmentStart;
+
+ SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) {
+ this.type = type;
+ this.id = id;
+ this.pathScoreSum = pathScoreSum;
+ this.pathSegmentCount = pathSegmentCount;
+ this.segmentStart = segmentStart;
+ }
+
+ public float score() {
+ switch (scoring) {
+ case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum;
+ case highestScore: return pathScoreSum;
+ default : throw new IllegalArgumentException("Unknown scoring " + scoring);
+ }
+ }
+
+ public float scoreWith(float additionalSegmentScore) {
+ switch (scoring) {
+ case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore );
+ case highestScore: return pathScoreSum + additionalSegmentScore;
+ default : throw new IllegalArgumentException("Unknown scoring " + scoring);
+ }
+ }
+
+ }
+
+}
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
index c7b131cc439..b6659ebeaa3 100644
--- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -1,18 +1,15 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
package com.yahoo.language.sentencepiece;
import com.google.common.annotations.Beta;
import com.google.inject.Inject;
-import com.yahoo.io.IOUtils;
import com.yahoo.language.Language;
+import com.yahoo.language.process.Encoder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import sentencepiece.SentencepieceModel;
-import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
@@ -30,34 +27,19 @@ import java.util.stream.Collectors;
* @author bratseth
*/
@Beta
-public class SentencePieceEncoder implements Segmenter {
-
- // TODO: Support characters beyond BMP
- enum TokenType { text, control, userDefined, unknown, unused }
-
- /** The scoring strategy to use for picking segments */
- public enum Scoring {
- /** Find the segmentation that has the highest score */
- highestScore,
- /** Find the segmentation that has the fewest segments, resolve ties by score sum */
- fewestSegments
- }
-
- private static final char spaceSymbol = '▁';
-
- private final boolean collapseUnknowns;
- private final Scoring scoring;
+public class SentencePieceEncoder implements Segmenter, Encoder {
private final Map<Language, Model> models;
+ private final SentencePieceAlgorithm algorithm;
+
@Inject
public SentencePieceEncoder(SentencePieceConfig config) {
this(new Builder(config));
}
public SentencePieceEncoder(Builder builder) {
- collapseUnknowns = builder.getCollapseUnknowns();
- scoring = builder.getScoring();
+ algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring());
models = builder.getModels().entrySet()
.stream()
@@ -78,7 +60,7 @@ public class SentencePieceEncoder implements Segmenter {
public List<String> segment(String rawInput, Language language) {
String input = normalize(rawInput);
var resultBuilder = new ResultBuilder<List<String>>(new ArrayList<>()) {
- public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) {
+ public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
result().add(input.substring(segmentStart, segmentEnd));
}
};
@@ -94,9 +76,10 @@ public class SentencePieceEncoder implements Segmenter {
* @param language the model to use, or Language.UNKNOWN to use the default model if any
* @return the list of zero or more token ids resulting from segmenting the input text
*/
+ @Override
public List<Integer> encode(String rawInput, Language language) {
var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
- public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) {
+ public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
result().add(segmentEnds[segmentEnd].id);
}
};
@@ -106,8 +89,18 @@ public class SentencePieceEncoder implements Segmenter {
}
/**
- * Encodes directly to a tensor.
+ * <p>Encodes directly to a tensor.</p>
+ *
+ * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order
+ * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small
+ * it will be truncated.</p>
+ *
+ * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token
+ * position as value.</p>
+ *
+ * <p>If the tensor is any other type IllegalArgumentException is thrown.</p>
*/
+ @Override
public Tensor encode(String rawInput, Language language, TensorType type) {
if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
// Build to a list first since we can't reverse a tensor builder
@@ -136,29 +129,10 @@ public class SentencePieceEncoder implements Segmenter {
}
}
- private <RESULTTYPE> void segment(String input, Language language, ResultBuilder<RESULTTYPE> resultBuilder) {
+ private <RESULTTYPE> void segment(String input, Language language,
+ ResultBuilder<RESULTTYPE> resultBuilder) {
Model model = resolveFrom(language);
-
- SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1];
- segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0);
- int start = 0;
- while (start < input.length()) { // segment from this position to the end of the text
- Trie.Node node = model.tokens.root;
- int characterPosition = start;
- while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position
- node = node.children.get(input.charAt(characterPosition++));
- int length = characterPosition - start;
- if (node != null && node.isToken() && node.type != TokenType.unused) {
- float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score;
- addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds);
- }
- else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable
- addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds);
- }
- }
- start++;
- }
- createResult(input, segmentEnds, resultBuilder);
+ algorithm.segment(input, resultBuilder, model);
}
private Model resolveFrom(Language language) {
@@ -168,88 +142,6 @@ public class SentencePieceEncoder implements Segmenter {
throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured");
}
- private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) {
- if (segmentEnds[end] == null ||
- segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) {
- segmentEnds[end] = new SegmentEnd(type, id,
- segmentEnds[start].pathScoreSum + score,
- segmentEnds[start].pathSegmentCount + 1,
- start);
- }
- }
-
- private <RESULTTYPE> void createResult(String input, SegmentEnd[] segmentEnds, ResultBuilder<RESULTTYPE> resultBuilder) {
- if (collapseUnknowns) {
- int segmentEnd = input.length();
- int collapsedSegmentEnd = segmentEnd;
- while (segmentEnd > 0) {
- if (segmentEnds[segmentEnd].type != TokenType.unknown ) {
- if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment
- resultBuilder.add(segmentEnd, collapsedSegmentEnd, segmentEnds);
- }
- resultBuilder.add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds);
- collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart;
- }
- segmentEnd = segmentEnds[segmentEnd].segmentStart;
- }
- }
- else {
- int segmentEnd = input.length();
- while (segmentEnd > 0) {
- resultBuilder.add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds);
- segmentEnd = segmentEnds[segmentEnd].segmentStart;
- }
- }
- }
-
- private static abstract class ResultBuilder<RESULTTYPE> {
-
- private final RESULTTYPE result;
-
- ResultBuilder(RESULTTYPE result) {
- this.result = result;
- }
-
- abstract void add(int start, int end, SegmentEnd[] segmentEnds);
-
- RESULTTYPE result() { return result; }
-
- }
-
- private final class SegmentEnd {
-
- final TokenType type;
- final int id;
- final float pathScoreSum;
- final int pathSegmentCount;
- final int segmentStart;
-
- SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) {
- this.type = type;
- this.id = id;
- this.pathScoreSum = pathScoreSum;
- this.pathSegmentCount = pathSegmentCount;
- this.segmentStart = segmentStart;
- }
-
- public float score() {
- switch (scoring) {
- case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum;
- case highestScore: return pathScoreSum;
- default : throw new IllegalArgumentException("Unknown scoring " + scoring);
- }
- }
-
- public float scoreWith(float additionalSegmentScore) {
- switch (scoring) {
- case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore );
- case highestScore: return pathScoreSum + additionalSegmentScore;
- default : throw new IllegalArgumentException("Unknown scoring " + scoring);
- }
- }
-
- }
-
public String normalize(String s) {
StringBuilder b = new StringBuilder(s.length() + 1);
boolean queuedSpace = true; // Always start by one space
@@ -260,7 +152,7 @@ public class SentencePieceEncoder implements Segmenter {
}
else {
if (queuedSpace) {
- b.append(spaceSymbol);
+ b.append(SentencePieceAlgorithm.spaceSymbol);
queuedSpace = false;
}
b.append(c);
@@ -269,79 +161,6 @@ public class SentencePieceEncoder implements Segmenter {
return b.toString();
}
- private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) {
- switch (type) {
- case USER_DEFINED : return TokenType.userDefined;
- case UNKNOWN : return TokenType.unknown;
- case NORMAL : return TokenType.text;
- case CONTROL : return TokenType.control;
- case UNUSED : return TokenType.unused;
- default : throw new IllegalArgumentException("Unknkown token type " + type);
- }
- }
-
- private static class Trie {
-
- final Node root = new Node();
-
- void add(TokenType type, int id, String word, float score) {
- Node current = root;
- for (char l : word.toCharArray())
- current = current.children.computeIfAbsent(l, c -> new Node());
- current.type = type;
- current.id = id;
- current.score = score;
- }
-
- static class Node {
-
- Integer id;
- TokenType type;
- Float score;
- private final Map<Character, Node> children = new HashMap<>();
-
- boolean isToken() { return type != null; }
-
- }
-
- }
-
- private static final class Model {
-
- final Path source;
- final Language language;
- final float minScore;
- final float maxScore;
- final Trie tokens = new Trie();
-
- Model(Language language, Path path) {
- try {
- this.source = path;
- this.language = language;
- var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile()));
- float minScore = Float.MAX_VALUE;
- float maxScore = Float.MIN_VALUE;
- for (int i = 0; i < sp.getPiecesCount(); i++) {
- var piece = sp.getPieces(i);
- tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore());
- minScore = Math.min(piece.getScore(), minScore);
- maxScore = Math.max(piece.getScore(), maxScore);
- }
- this.minScore = minScore;
- this.maxScore = maxScore;
- }
- catch (IOException e) {
- throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e);
- }
- }
-
- @Override
- public String toString() {
- return "SentencePiece model for " + language + ": '" + source + "'";
- }
-
- }
-
public static class Builder {
private final Map<Language, Path> models = new HashMap<>();
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java
new file mode 100644
index 00000000000..782030a8e4d
--- /dev/null
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java
@@ -0,0 +1,13 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+/**
+ * SentencePiece token types
+ *
+ * @author bratseth
+ */
+enum TokenType {
+
+ text, control, userDefined, unknown, unused
+
+}
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java
new file mode 100644
index 00000000000..f3287a49517
--- /dev/null
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java
@@ -0,0 +1,36 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A simple trie for sentencepiece token lookup
+ *
+ * @author bratseth
+ */
+public class Trie {
+
+ final Node root = new Node();
+
+ void add(TokenType type, int id, String word, float score) {
+ Node current = root;
+ for (char l : word.toCharArray())
+ current = current.children.computeIfAbsent(l, c -> new Node());
+ current.type = type;
+ current.id = id;
+ current.score = score;
+ }
+
+ static class Node {
+
+ Integer id;
+ TokenType type;
+ Float score;
+ final Map<Character, Node> children = new HashMap<>();
+
+ boolean isToken() { return type != null; }
+
+ }
+
+}
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index 5b77324a6fc..d60d7386d4b 100644
--- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -3,7 +3,6 @@
package com.yahoo.language.sentencepiece;
import com.yahoo.language.Language;
-import com.yahoo.tensor.Tensor;
import org.junit.Test;
import java.io.File;
@@ -69,7 +68,7 @@ public class SentencePieceTest {
public void testHighestScore() {
var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
.addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
- .setScoring(SentencePieceEncoder.Scoring.highestScore));
+ .setScoring(Scoring.highestScore));
tester.assertSegmented("h", "▁h");
tester.assertSegmented("he", "▁he");
tester.assertSegmented("hel", "▁h", "el");