Diffstat (limited to 'linguistics-components/src')
-rw-r--r--  linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java                        |  60
-rw-r--r--  linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java                |  47
-rw-r--r--  linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java                      |  17
-rw-r--r--  linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java       |  90
-rw-r--r--  linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java         | 220
-rw-r--r--  linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java                    |  13
-rw-r--r--  linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java                         |  36
-rw-r--r--  linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java                 |   7
-rw-r--r--  linguistics-components/src/main/protobuf/sentencepiece_model.proto                                      | 310
-rw-r--r--  linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def   |  18
-rw-r--r--  linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java | 59
-rw-r--r--  linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java            |  89
-rw-r--r--  linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java          |  49
-rw-r--r--  linguistics-components/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model                          | bin 0 -> 400869 bytes
-rw-r--r--  linguistics-components/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model                           | bin 0 -> 300865 bytes
15 files changed, 1015 insertions(+), 0 deletions(-)
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
new file mode 100644
index 00000000000..74f300057dc
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
@@ -0,0 +1,60 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.io.IOUtils;
+import com.yahoo.language.Language;
+import sentencepiece.SentencepieceModel;
+
+import java.io.IOException;
+import java.nio.file.Path;
+
+/**
+ * A SentencePiece model
+ *
+ * @author bratseth
+ */
+final class Model {
+
+ final Path source;
+ final Language language;
+ final float minScore;
+ final float maxScore;
+ final Trie tokens = new Trie();
+
+ Model(Language language, Path path) {
+ try {
+ this.source = path;
+ this.language = language;
+ var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile()));
+ float minScore = Float.MAX_VALUE;
+            float maxScore = -Float.MAX_VALUE;
+ for (int i = 0; i < sp.getPiecesCount(); i++) {
+ var piece = sp.getPieces(i);
+ tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore());
+ minScore = Math.min(piece.getScore(), minScore);
+ maxScore = Math.max(piece.getScore(), maxScore);
+ }
+ this.minScore = minScore;
+ this.maxScore = maxScore;
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e);
+ }
+ }
+
+ private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) {
+ switch (type) {
+ case USER_DEFINED : return TokenType.userDefined;
+ case UNKNOWN : return TokenType.unknown;
+ case NORMAL : return TokenType.text;
+ case CONTROL : return TokenType.control;
+ case UNUSED : return TokenType.unused;
+            default : throw new IllegalArgumentException("Unknown token type " + type);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "SentencePiece model for " + language + ": '" + source + "'";
+ }
+
+}
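
Model eagerly parses the serialized SentencePiece protobuf into the trie and records the global score bounds, which the segmentation algorithm uses to score user-defined and unknown pieces. A minimal usage sketch (the constructor is package-private, as used by SentencePieceEncoder below; the path is the test model added in this change):

    Model model = new Model(Language.ENGLISH,
                            Path.of("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model"));
    // model.tokens now holds a character trie over all pieces;
    // model.minScore and model.maxScore bound the piece scores.
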
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java
new file mode 100644
index 00000000000..2141505374c
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java
@@ -0,0 +1,47 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+/**
+ * Builds a result from a sentencepiece tokenization by being called for each segment in reverse
+ *
+ * @param <RESULTTYPE> the type of result this produces
+ * @author bratseth
+ */
+abstract class ResultBuilder<RESULTTYPE> {
+
+ private final RESULTTYPE result;
+
+ ResultBuilder(RESULTTYPE result) {
+ this.result = result;
+ }
+
+ /** Called for each segment, starting from the last and working backwards */
+ abstract void add(int start, int end, SentencePieceAlgorithm.SegmentEnd[] segmentEnds);
+
+ RESULTTYPE result() {return result;}
+
+ void build(String input, SentencePieceAlgorithm.SegmentEnd[] segmentEnds, boolean collapseUnknowns) {
+ if (collapseUnknowns) {
+ int segmentEnd = input.length();
+ int collapsedSegmentEnd = segmentEnd;
+ while (segmentEnd > 0) {
+                if (segmentEnds[segmentEnd].type != TokenType.unknown) {
+ if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment
+ add(segmentEnd, collapsedSegmentEnd, segmentEnds);
+ }
+ add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds);
+ collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart;
+ }
+ segmentEnd = segmentEnds[segmentEnd].segmentStart;
+ }
+ }
+ else {
+ int segmentEnd = input.length();
+ while (segmentEnd > 0) {
+ add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds);
+ segmentEnd = segmentEnds[segmentEnd].segmentStart;
+ }
+ }
+ }
+
+}
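
Since add() is called for each segment starting from the end of the input, concrete builders append as they go and the caller reverses the result afterwards. A sketch of a string-collecting builder (it mirrors the anonymous subclass in SentencePieceEncoder.segment later in this change; `input` is assumed to be the normalized input string in scope):

    ResultBuilder<List<String>> segments = new ResultBuilder<>(new ArrayList<>()) {
        void add(int start, int end, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
            result().add(input.substring(start, end)); // last segment arrives first
        }
    };
    // after build(...) has run: Collections.reverse(segments.result());
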
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java
new file mode 100644
index 00000000000..6c8560abee7
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java
@@ -0,0 +1,17 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+/**
+ * The scoring strategy to use for picking segments
+ *
+ * @author bratseth
+ */
+public enum Scoring {
+
+ /** Find the segmentation that has the highest score */
+ highestScore,
+
+ /** Find the segmentation that has the fewest segments, resolve ties by score sum */
+ fewestSegments
+
+}
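
The two strategies can pick different segmentations of the same text. With the en.wiki.bpe.vs10000 test model used by the tests in this change, fewestSegments (the default) segments "hello" into [▁hel, lo] while highestScore yields [▁h, el, lo]. A strategy is selected on the encoder builder; a sketch:

    SentencePieceEncoder encoder = new SentencePieceEncoder.Builder()
            .addDefaultModel(Path.of("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model"))
            .setScoring(Scoring.highestScore)
            .build();
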
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java
new file mode 100644
index 00000000000..1659e3c0fa7
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java
@@ -0,0 +1,90 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+/**
+ * SentencePiece algorithm implementation
+ *
+ * @author bratseth
+ */
+class SentencePieceAlgorithm {
+
+ // TODO: Support characters beyond BMP
+
+ static final char spaceSymbol = '▁';
+
+ private final boolean collapseUnknowns;
+ private final Scoring scoring;
+
+ SentencePieceAlgorithm(boolean collapseUnknowns, Scoring scoring) {
+ this.collapseUnknowns = collapseUnknowns;
+ this.scoring = scoring;
+ }
+
+ public <RESULTTYPE> void segment(String input, ResultBuilder<RESULTTYPE> resultBuilder, Model model) {
+ SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1];
+ segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0);
+ int start = 0;
+ while (start < input.length()) { // segment from this position to the end of the text
+ Trie.Node node = model.tokens.root;
+ int characterPosition = start;
+            while (node != null && characterPosition < input.length()) { // traverse the trie one character at a time from this position
+ node = node.children.get(input.charAt(characterPosition++));
+ int length = characterPosition - start;
+ if (node != null && node.isToken() && node.type != TokenType.unused) {
+ float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score;
+ addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds);
+ }
+ else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable
+ addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds);
+ }
+ }
+ start++;
+ }
+ resultBuilder.build(input, segmentEnds, collapseUnknowns);
+ }
+
+ private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) {
+ if (segmentEnds[end] == null ||
+ segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) {
+ segmentEnds[end] = new SegmentEnd(type, id,
+ segmentEnds[start].pathScoreSum + score,
+ segmentEnds[start].pathSegmentCount + 1,
+ start);
+ }
+ }
+
+ final class SegmentEnd {
+
+ final TokenType type;
+ final int id;
+ final float pathScoreSum;
+ final int pathSegmentCount;
+ final int segmentStart;
+
+ SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) {
+ this.type = type;
+ this.id = id;
+ this.pathScoreSum = pathScoreSum;
+ this.pathSegmentCount = pathSegmentCount;
+ this.segmentStart = segmentStart;
+ }
+
+ public float score() {
+ switch (scoring) {
+ case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum;
+ case highestScore: return pathScoreSum;
+ default : throw new IllegalArgumentException("Unknown scoring " + scoring);
+ }
+ }
+
+ public float scoreWith(float additionalSegmentScore) {
+ switch (scoring) {
+ case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore );
+ case highestScore: return pathScoreSum + additionalSegmentScore;
+ default : throw new IllegalArgumentException("Unknown scoring " + scoring);
+ }
+ }
+
+ }
+
+}
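
segment() is a forward dynamic program: segmentEnds[i] holds the best-scoring segmentation of input[0..i], and every candidate token found in the trie relaxes the entry at its end position. The winning path is then read out backwards by following segmentStart links, which is what ResultBuilder.build above does; a sketch of that readout:

    int end = input.length();
    while (end > 0) {
        SentencePieceAlgorithm.SegmentEnd segment = segmentEnds[end];
        System.out.println(input.substring(segment.segmentStart, end) + " id=" + segment.id);
        end = segment.segmentStart; // step back to the previous segment
    }
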
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
new file mode 100644
index 00000000000..b6659ebeaa3
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -0,0 +1,220 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+import com.google.common.annotations.Beta;
+import com.google.inject.Inject;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Encoder;
+import com.yahoo.language.process.Segmenter;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Integration with https://github.com/google/sentencepiece
+ * through http://docs.djl.ai/extensions/sentencepiece/index.html
+ *
+ * SentencePiece is a language-agnostic tokenizer for neural nets.
+ *
+ * @author bratseth
+ */
+@Beta
+public class SentencePieceEncoder implements Segmenter, Encoder {
+
+ private final Map<Language, Model> models;
+
+ private final SentencePieceAlgorithm algorithm;
+
+ @Inject
+ public SentencePieceEncoder(SentencePieceConfig config) {
+ this(new Builder(config));
+ }
+
+ public SentencePieceEncoder(Builder builder) {
+ algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring());
+
+ models = builder.getModels().entrySet()
+ .stream()
+ .map(e -> new Model(e.getKey(), e.getValue()))
+ .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m));
+ if (models.isEmpty())
+ throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured");
+ }
+
+ /**
+ * Segments the given text into token segments using the SentencePiece algorithm
+ *
+     * @param rawInput the text to segment. Any sequence of characters in the basic multilingual plane (BMP) is supported.
+ * @param language the model to use, or Language.UNKNOWN to use the default model if any
+ * @return the list of zero or more tokens resulting from segmenting the input text
+ */
+ @Override
+ public List<String> segment(String rawInput, Language language) {
+ String input = normalize(rawInput);
+ var resultBuilder = new ResultBuilder<List<String>>(new ArrayList<>()) {
+ public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
+ result().add(input.substring(segmentStart, segmentEnd));
+ }
+ };
+ segment(input, language, resultBuilder);
+ Collections.reverse(resultBuilder.result());
+ return resultBuilder.result();
+ }
+
+ /**
+ * Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids.
+ *
+     * @param rawInput the text to segment. Any sequence of characters in the basic multilingual plane (BMP) is supported.
+ * @param language the model to use, or Language.UNKNOWN to use the default model if any
+ * @return the list of zero or more token ids resulting from segmenting the input text
+ */
+ @Override
+ public List<Integer> encode(String rawInput, Language language) {
+ var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
+ public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
+ result().add(segmentEnds[segmentEnd].id);
+ }
+ };
+ segment(normalize(rawInput), language, resultBuilder);
+ Collections.reverse(resultBuilder.result());
+ return resultBuilder.result();
+ }
+
+ /**
+ * <p>Encodes directly to a tensor.</p>
+ *
+ * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order
+ * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small
+ * it will be truncated.</p>
+ *
+     * <p>If the tensor type is 1-d sparse this will return a tensor containing the token strings as keys and the token
+ * position as value.</p>
+ *
+     * <p>If the tensor type is any other type, IllegalArgumentException is thrown.</p>
+ */
+ @Override
+ public Tensor encode(String rawInput, Language language, TensorType type) {
+ if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
+ // Build to a list first since we can't reverse a tensor builder
+ List<Integer> values = encode(rawInput, language);
+
+ long maxSize = values.size();
+ if (type.dimensions().get(0).size().isPresent())
+ maxSize = Math.min(maxSize, type.dimensions().get(0).size().get());
+
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ for (int i = 0; i < maxSize; i++)
+ builder.cell(values.get(i), i);
+ return builder.build();
+ }
+ else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) {
+ // Build to a list first since we can't reverse a tensor builder
+ List<String> values = segment(rawInput, language);
+
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ for (int i = 0; i < values.size(); i++)
+ builder.cell(TensorAddress.ofLabels(values.get(i)), i);
+ return builder.build();
+ }
+ else {
+ throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type);
+ }
+ }
+
+ private <RESULTTYPE> void segment(String input, Language language,
+ ResultBuilder<RESULTTYPE> resultBuilder) {
+ Model model = resolveFrom(language);
+ algorithm.segment(input, resultBuilder, model);
+ }
+
+ private Model resolveFrom(Language language) {
+        // Disregard the language if the only model configured is the default model
+ if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
+ if (models.containsKey(language)) return models.get(language);
+ throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured");
+ }
+
+ public String normalize(String s) {
+ StringBuilder b = new StringBuilder(s.length() + 1);
+        boolean queuedSpace = true; // Always start with one space
+        for (int i = 0; i < s.length(); i++) {
+            char c = s.charAt(i);
+            if (c == ' ') {
+ queuedSpace = true;
+ }
+ else {
+ if (queuedSpace) {
+ b.append(SentencePieceAlgorithm.spaceSymbol);
+ queuedSpace = false;
+ }
+ b.append(c);
+ }
+ }
+ return b.toString();
+ }
+
+ public static class Builder {
+
+ private final Map<Language, Path> models = new HashMap<>();
+ private boolean collapseUnknowns = true;
+ private Scoring scoring = Scoring.fewestSegments;
+
+ public Builder() {
+ }
+
+ private Builder(SentencePieceConfig config) {
+ collapseUnknowns = config.collapseUnknowns();
+ scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments
+ : Scoring.highestScore;
+ for (SentencePieceConfig.Model model : config.model()) {
+ addModel(Language.fromLanguageTag(model.language()), model.path());
+ }
+ }
+
+ public void addModel(Language language, Path model) {
+ models.put(language, model);
+ }
+
+        /**
+         * Adds the model to be used when the language is unknown, or when it is the only model configured.
+         * Equivalent to addModel(Language.UNKNOWN, model).
+         */
+ public Builder addDefaultModel(Path model) {
+ addModel(Language.UNKNOWN, model);
+ return this;
+ }
+ public Map<Language, Path> getModels() { return models; }
+
+        /**
+         * Sets whether consecutive unknown characters should be collapsed into one large unknown token (default)
+         * or be returned as single-character tokens.
+         */
+ public Builder setCollapseUnknowns(boolean collapseUnknowns) {
+ this.collapseUnknowns = collapseUnknowns;
+ return this;
+ }
+ public boolean getCollapseUnknowns() { return collapseUnknowns; }
+
+ /** Sets the scoring strategy to use when picking a segmentation. Default: fewestSegments. */
+ public Builder setScoring(Scoring scoring) {
+ this.scoring = scoring;
+ return this;
+ }
+ public Scoring getScoring() { return scoring; }
+
+ public SentencePieceEncoder build() {
+ if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied");
+ return new SentencePieceEncoder(this);
+ }
+
+ }
+
+}
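
End-to-end usage, with outputs as asserted by the tests later in this change (segments and ids are specific to the en.wiki.bpe.vs10000 test model):

    SentencePieceEncoder encoder = new SentencePieceEncoder.Builder()
            .addDefaultModel(Path.of("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model"))
            .build();
    encoder.segment("hello world!", Language.UNKNOWN);   // [▁hel, lo, ▁world, !]
    encoder.encode("hello, world!", Language.UNKNOWN);   // [908, 1418, 9934, 501, 9960]
    encoder.encode("hello, world!", Language.UNKNOWN,
                   TensorType.fromSpec("tensor(d[10])")); // ids zero-padded to size 10
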
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java
new file mode 100644
index 00000000000..782030a8e4d
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java
@@ -0,0 +1,13 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+/**
+ * SentencePiece token types
+ *
+ * @author bratseth
+ */
+enum TokenType {
+
+ text, control, userDefined, unknown, unused
+
+}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java
new file mode 100644
index 00000000000..8e7c2db2ed3
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java
@@ -0,0 +1,36 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.language.sentencepiece;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A simple trie for sentencepiece token lookups.
+ *
+ * @author bratseth
+ */
+class Trie {
+
+ final Node root = new Node();
+
+ void add(TokenType type, int id, String word, float score) {
+ Node current = root;
+ for (char l : word.toCharArray())
+ current = current.children.computeIfAbsent(l, c -> new Node());
+ current.type = type;
+ current.id = id;
+ current.score = score;
+ }
+
+ static class Node {
+
+ Integer id;
+ TokenType type;
+ Float score;
+ final Map<Character, Node> children = new HashMap<>();
+
+ boolean isToken() { return type != null; }
+
+ }
+
+}
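
A small sketch of building and walking the trie one character at a time, the way SentencePieceAlgorithm.segment traverses it (the ids and scores are made up):

    Trie trie = new Trie();
    trie.add(TokenType.text, 17, "he", -1.5f);
    trie.add(TokenType.text, 42, "hel", -2.0f);

    Trie.Node node = trie.root;
    for (char c : "hel".toCharArray()) {
        node = node.children.get(c);
        if (node == null) break;   // no piece has this prefix
        if (node.isToken())
            System.out.println("piece id " + node.id + ", score " + node.score);
    }
    // prints id 17 (after "he"), then id 42 (after "hel")
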
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java
new file mode 100644
index 00000000000..3f97277c489
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java
@@ -0,0 +1,7 @@
+// Copyright 2021 Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/linguistics-components/src/main/protobuf/sentencepiece_model.proto b/linguistics-components/src/main/protobuf/sentencepiece_model.proto
new file mode 100644
index 00000000000..39626aede53
--- /dev/null
+++ b/linguistics-components/src/main/protobuf/sentencepiece_model.proto
@@ -0,0 +1,310 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+// TODO(taku): Needs to use LITE RUNTIME in OSS release.
+option optimize_for = LITE_RUNTIME;
+
+package sentencepiece;
+
+// TrainerSpec encodes the various parameters for SentencePiece training.
+message TrainerSpec {
+ ///////////////////////////////////////////////////////////////////
+ // General parameters
+ //
+ // Input corpus files.
+ // Trainer accepts the following two formats:
+ // A) Monolingual: plain text, one sentence per line.
+ // B) Bilingual: TSV, source sentence <tab> target sentence
+ // When bilingual data is passed, shared vocabulary model is built.
+ // Note that the input file must be raw corpus, not a preprocessed corpus.
+ // Trainer only loads the first `input_sentence_size` sentences specified
+ // with this parameter.
+ repeated string input = 1;
+
+ // Input corpus format:
+ // "text": one-sentence-per-line text format (default)
+ // "tsv": sentence <tab> freq
+ optional string input_format = 7;
+
+ // Output model file prefix.
+ // <model_prefix>.model and <model_prefix>.vocab are generated.
+ optional string model_prefix = 2;
+
+ // Model type. only have UNIGRAM now.
+ enum ModelType {
+ UNIGRAM = 1; // Unigram language model with dynamic algorithm
+ BPE = 2; // Byte Pair Encoding
+    WORD = 3;    // Delimited by whitespace.
+ CHAR = 4; // tokenizes into character sequence
+ }
+ optional ModelType model_type = 3 [default = UNIGRAM];
+
+ // Vocabulary size. 8k is the default size.
+ optional int32 vocab_size = 4 [default = 8000];
+
+ // List of the languages this model can accept.
+ // Since the model is language-agnostic, this field is used as a reference.
+ repeated string accept_language = 5;
+
+ // Size of self-test samples, which are encoded in the model file.
+ optional int32 self_test_sample_size = 6 [default = 0];
+
+ ///////////////////////////////////////////////////////////////////
+ // Training parameters.
+ //
+ // Uses characters which cover the corpus with the ratio of `chars_coverage`.
+  // This parameter determines the basic alphabet of sentence pieces.
+ // 1.0 - `chars_coverage` characters are treated as UNK.
+ // See also required_chars field.
+ optional float character_coverage = 10 [default = 0.9995];
+
+ // Maximum size of sentences the trainer loads from `input` parameter.
+ // Trainer simply loads the `input` files in sequence.
+ // It is better to shuffle the input corpus randomly.
+ optional uint64 input_sentence_size = 11 [default = 0];
+ optional bool shuffle_input_sentence = 19 [default = true];
+
+ // Maximum size of sentences to make seed sentence pieces.
+ // Extended suffix array is constructed to extract frequent
+ // sub-strings from the corpus. This uses 20N working space,
+ // where N is the size of corpus.
+ optional int32 mining_sentence_size = 12 [deprecated = true];
+
+ // Maximum size of sentences to train sentence pieces.
+ optional int32 training_sentence_size = 13 [deprecated = true];
+
+ // The size of seed sentencepieces.
+ // `seed_sentencepiece_size` must be larger than `vocab_size`.
+ optional int32 seed_sentencepiece_size = 14 [default = 1000000];
+
+ // In every EM sub-iterations, keeps top
+ // `shrinking_factor` * `current sentencepieces size` with respect to
+ // the loss of the sentence piece. This value should be smaller than 1.0.
+ optional float shrinking_factor = 15 [default = 0.75];
+
+  // The maximum sentence length in bytes. Sentences longer than
+  // `max_sentence_length` are simply ignored.
+ // Longer input tends to bring the following risks:
+ // * Overflow during EM training (unigram language model only)
+ // * Performance drop because of O(n log n) cost in BPE.
+ optional int32 max_sentence_length = 18 [default = 4192];
+
+ // Number of threads in the training.
+ optional int32 num_threads = 16 [default = 16];
+
+ // Number of EM sub iterations.
+ optional int32 num_sub_iterations = 17 [default = 2];
+
+ ///////////////////////////////////////////////////////////////////
+ // SentencePiece parameters which control the shapes of sentence piece.
+ //
+ // Maximum length of sentencepiece.
+ optional int32 max_sentencepiece_length = 20 [default = 16];
+
+ // Uses Unicode script to split sentence pieces.
+ // When `split_by_unicode_script` is true, we do not allow sentence piece to
+ // include multiple Unicode scripts, e.g. "F1" is not a valid piece.
+ // Exception: CJ characters (Hiragana/Katakana/Han) are all handled
+ // as one script type, since Japanese word can consist of multiple scripts.
+ // This exception is always applied regardless of the accept-language
+ // parameter.
+ optional bool split_by_unicode_script = 21 [default = true];
+
+  // When `split_by_number` is true, puts a boundary at number/non-number
+  // transitions. To treat "F1" as one token, set this flag
+  // to false.
+ optional bool split_by_number = 23 [default = true];
+
+ // Use a white space to split sentence pieces.
+ // When `split_by_whitespace` is false, we may have the piece containing
+ // a white space in the middle. e.g., "in_the".
+ optional bool split_by_whitespace = 22 [default = true];
+
+ // Adds whitespace symbol (_) as a suffix instead of prefix. e.g., _hello =>
+ // hello_. When `treat_whitespace_as_suffix` is true,
+ // NormalizerSpec::add_dummy_prefix will add the dummy whitespace to the end
+ // of sentence.
+ optional bool treat_whitespace_as_suffix = 24 [default = false];
+
+ // Allows pieces that only contain whitespaces instead of appearing only as
+ // prefix or suffix of other pieces.
+ optional bool allow_whitespace_only_pieces = 26 [default = false];
+
+ // Split all digits (0-9) into separate pieces.
+ optional bool split_digits = 25 [default = false];
+
+ ///////////////////////////////////////////////////////////////////
+ // Vocabulary management
+ //
+ // Defines control symbols used as an indicator to
+ // change the behavior of the decoder. <s> and </s> are pre-defined.
+ // We can use this field to encode various meta information,
+ // including language indicator in multilingual model.
+ // These symbols are not visible to users, but visible to
+ // the decoder. Note that when the input sentence contains control symbols,
+ // they are not treated as one token, but segmented into normal pieces.
+ // Control symbols must be inserted independently from the segmentation.
+ repeated string control_symbols = 30;
+
+ // Defines user defined symbols.
+ // These symbols are added with extremely high score
+ // so they are always treated as one unique symbol in any context.
+ // Typical usage of user_defined_symbols is placeholder for named entities.
+ repeated string user_defined_symbols = 31;
+
+ // Defines required characters. Each UTF8 character in this string is included
+ // in the character set regardless of character_coverage value. Unlike
+ // user_defined_symbols, these characters have scores based on the frequency
+ // on input sentences, and the model can form subwords using characters
+ // in this field.
+ optional string required_chars = 36;
+
+ // Decomposes unknown pieces into UTF-8 bytes.
+ optional bool byte_fallback = 35 [default = false];
+
+ // When creating the vocabulary file, defines whether or not to additionally
+ // output the score for each piece.
+ optional bool vocabulary_output_piece_score = 32 [default = true];
+
+  // `vocab_size` is treated as a hard limit: training crashes if
+  // the model can not produce a vocab of size `vocab_size`.
+  // When `hard_vocab_limit` is false, vocab_size is treated
+  // as a soft limit. Note that when model_type=char,
+  // hard_vocab_limit = false is always assumed.
+ optional bool hard_vocab_limit = 33 [default = true];
+
+ // use all symbols for vocab extraction. This flag is valid
+ // if model type is either CHAR or WORD
+ optional bool use_all_vocab = 34 [default = false];
+
+ ///////////////////////////////////////////////////////////////////
+ // Reserved special meta tokens.
+ // * -1 is not used.
+ // * unk_id must not be -1.
+  // Ids must start with 0 and be contiguous.
+ optional int32 unk_id = 40 [default = 0]; // <unk>
+ optional int32 bos_id = 41 [default = 1]; // <s>
+ optional int32 eos_id = 42 [default = 2]; // </s>
+ optional int32 pad_id = 43 [default = -1]; // <pad> (padding)
+ optional string unk_piece = 45 [default = "<unk>"];
+ optional string bos_piece = 46 [default = "<s>"];
+ optional string eos_piece = 47 [default = "</s>"];
+ optional string pad_piece = 48 [default = "<pad>"];
+
+ // Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
+ // since this character can be useful both for user and
+ // developer. We can easily figure out that <unk> is emitted.
+ optional string unk_surface = 44 [default = " \xE2\x81\x87 "];
+
+ // Increase bit depth to allow unigram model training on large
+ // (>10M sentences) corpora. A Side-effect of enabling this flag
+ // is increased memory usage.
+ optional bool train_extremely_large_corpus = 49 [default = false];
+
+ // Customized extensions: the range of field numbers
+ // are open to third-party extensions.
+ extensions 200 to max;
+}
+
+// NormalizerSpec encodes the various parameters for string normalization
+message NormalizerSpec {
+ // name of normalization rule.
+ optional string name = 1;
+
+ // Pre-compiled normalization rule created by
+ // Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method.
+ // Usually this field is set by Builder::GetNormalizerSpec() method.
+ optional bytes precompiled_charsmap = 2;
+
+ // Adds dummy whitespace at the beginning of text in order to
+ // treat "world" in "world" and "hello world" in the same way.
+ optional bool add_dummy_prefix = 3 [default = true];
+
+ // Removes leading, trailing, and duplicate internal whitespace.
+ optional bool remove_extra_whitespaces = 4 [default = true];
+
+ // Replaces whitespace with meta symbol.
+ // This field must be true to train sentence piece model.
+ optional bool escape_whitespaces = 5 [default = true];
+
+ // Custom normalization rule file in TSV format.
+ // https://github.com/google/sentencepiece/blob/master/doc/normalization.md
+ // This field is only used in SentencePieceTrainer::Train() method, which
+ // compiles the rule into the binary rule stored in `precompiled_charsmap`.
+ optional string normalization_rule_tsv = 6;
+
+ // Customized extensions: the range of field numbers
+ // are open to third-party extensions.
+ extensions 200 to max;
+}
+
+// Proto to store samples for self-testing.
+message SelfTestData {
+ message Sample {
+ optional string input = 1;
+ optional string expected = 2;
+ }
+ repeated Sample samples = 1;
+
+ // Customized extensions: the range of field numbers
+ // are open to third-party extensions.
+ extensions 200 to max;
+}
+
+// ModelProto stores model parameters.
+// SentencePieceProcessor is supposed to be self-contained.
+// All settings/parameters which may change the behavior must be encoded
+// in ModelProto.
+message ModelProto {
+ message SentencePiece {
+ enum Type {
+ NORMAL = 1; // normal symbol
+ UNKNOWN = 2; // unknown symbol. only <unk> for now.
+ CONTROL = 3; // control symbols. </s>, <s>, <2ja> etc.
+ USER_DEFINED = 4; // user defined symbols.
+ // Typical usage of USER_DEFINED symbol
+ // is placeholder.
+ BYTE = 6; // byte symbols. Used when `byte_fallback` is true.
+ UNUSED = 5; // this piece is not used.
+ }
+ optional string piece = 1; // piece must not be empty.
+ optional float score = 2;
+ optional Type type = 3 [default = NORMAL];
+
+ // Customized extensions: the range of field numbers
+ // are open to third-party extensions.
+ extensions 200 to max;
+ }
+
+ // Sentence pieces with scores.
+ repeated SentencePiece pieces = 1;
+
+ // Spec used to generate this model file.
+ optional TrainerSpec trainer_spec = 2;
+
+ // Spec for text normalization.
+ optional NormalizerSpec normalizer_spec = 3;
+
+ // Stores sample input and its expected segmentation to verify the model.
+ optional SelfTestData self_test_data = 4;
+
+ // Spec for text de-normalization.
+ optional NormalizerSpec denormalizer_spec = 5;
+
+ // Customized extensions: the range of field numbers
+ // are open to third-party extensions.
+ extensions 200 to max;
+} \ No newline at end of file
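
The Java classes generated from this schema are what Model.java above reads; a sketch of parsing a model file directly (the file name is a placeholder):

    byte[] bytes = Files.readAllBytes(Path.of("en.wiki.bpe.vs10000.model"));
    var proto = SentencepieceModel.ModelProto.parseFrom(bytes);
    for (var piece : proto.getPiecesList())   // repeated SentencePiece pieces = 1
        System.out.println(piece.getPiece() + " " + piece.getScore() + " " + piece.getType());
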
diff --git a/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def b/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def
new file mode 100644
index 00000000000..b91c0c45dc4
--- /dev/null
+++ b/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def
@@ -0,0 +1,18 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+# Configures com.yahoo.language.sentencepiece.SentencePieceEncoder
+
+namespace=language.sentencepiece
+
+# Whether consecutive unknown characters should be collapsed into one large unknown token (default)
+# or be returned as single-character tokens.
+collapseUnknowns bool default=true
+
+# The scoring strategy to use when picking a segmentation.
+scoring enum { highestScore, fewestSegments } default=fewestSegments
+
+# The language a model is for, one of the language tags in com.yahoo.language.Language.
+# Use "unknown" for models to be used with any language.
+model[].language string
+# The path to the model relative to the application package root
+model[].path path \ No newline at end of file
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
new file mode 100644
index 00000000000..edbbe21ec53
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
@@ -0,0 +1,59 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.config.FileReference;
+import com.yahoo.language.Language;
+import org.junit.Test;
+
+/**
+ * @author bratseth
+ */
+public class SentencePieceConfigurationTest {
+
+ @Test
+ public void testEnglishTokenization() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence");
+ tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo");
+ }
+
+ @Test
+ public void testNoCollapse() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ b.collapseUnknowns(false);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
+ }
+
+ @Test
+ public void testHighestScore() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ b.scoring(SentencePieceConfig.Scoring.highestScore);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented("hello", "▁h", "el", "lo");
+ }
+
+ @Test
+ public void testMultiLanguageTokenization() {
+ var b = new SentencePieceConfig.Builder();
+ addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b);
+ addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
+ var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
+ tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
+ tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
+ }
+
+ private void addModel(String language, String file, SentencePieceConfig.Builder b) {
+ var mb = new SentencePieceConfig.Model.Builder();
+ mb.language(language);
+ mb.path(new FileReference(file));
+ b.model(mb);
+ }
+
+}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
new file mode 100644
index 00000000000..d60d7386d4b
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -0,0 +1,89 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.language.Language;
+import org.junit.Test;
+
+import java.io.File;
+
+/**
+ * @author bratseth
+ */
+public class SentencePieceTest {
+
+ @Test
+ public void testEnglishTokenization() {
+ var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ tester.assertSegmented("h", "▁h");
+ tester.assertSegmented("he", "▁he");
+ tester.assertSegmented("hel", "▁hel");
+ tester.assertSegmented("hello", "▁hel", "lo");
+ tester.assertSegmented("hei", "▁he", "i");
+ tester.assertSegmented("hei you", "▁he", "i", "▁you");
+ tester.assertSegmented("hei you", "▁he", "i", "▁you");
+ tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence");
+ tester.assertSegmented("hello world!", "▁hel", "lo", "▁world", "!");
+ tester.assertSegmented("Hello, world!", "▁", "H", "ello", ",", "▁world", "!");
+ tester.assertSegmented("HELLO, world!", "▁", "HELLO", ",", "▁world", "!");
+ tester.assertSegmented("KHJKJHHKJHHSH", "▁", "KHJKJHHKJHHSH");
+ tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo");
+ tester.assertSegmented(" hello ", "▁hel", "lo");
+ tester.assertSegmented(")(/&#()/\"\")", "▁)", "(", "/", "&", "#", "(", ")", "/", "\"", "\")");
+ tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁)", "(", "/", "&", "#", "(", "sm", "all", ")", "/", "\"", "in", "▁qu", "otes", "\")");
+ tester.assertSegmented("x.400AS", "▁x", ".", "4", "00", "AS");
+ tester.assertSegmented("A normal sentence. Yes one more.", "▁", "A", "▁normal", "▁sentence", ".", "▁", "Y", "es", "▁one", "▁more", ".");
+ }
+
+ @Test
+ public void testIntegerListEncoding() {
+ var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ tester.assertEncoded("hello, world!", 908, 1418, 9934, 501, 9960);
+ tester.assertEncoded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960);
+ }
+
+ @Test
+ public void testDenseTensorEncoding() {
+ var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ tester.assertEncoded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]");
+ tester.assertEncoded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]");
+ tester.assertEncoded("hello, world!", "tensor(d[2])", "[908,1418]");
+ }
+
+ @Test
+ public void testSparseTensorEncoding() {
+ var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ tester.assertEncoded("hello", "tensor(token{})", "{lo:1.0,'▁hel':0.0}");
+ }
+
+ @Test
+ public void testNoCollapse() {
+ var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
+ .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
+ .setCollapseUnknowns(false));
+ tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
+ }
+
+ @Test
+ public void testHighestScore() {
+ var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
+ .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
+ .setScoring(Scoring.highestScore));
+ tester.assertSegmented("h", "▁h");
+ tester.assertSegmented("he", "▁he");
+ tester.assertSegmented("hel", "▁h", "el");
+ tester.assertSegmented("hello", "▁h", "el", "lo");
+ }
+
+ @Test
+ public void testMultiLanguageTokenization() {
+ SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder();
+ builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath());
+ builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ var tester = new SentencePieceTester(builder);
+ tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
+ tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
+ tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
+ }
+
+}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
new file mode 100644
index 00000000000..1ba7c9b472d
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -0,0 +1,49 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+//
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.language.Language;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.file.Path;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+class SentencePieceTester {
+
+ private final SentencePieceEncoder encoder;
+
+ public SentencePieceTester(Path model) {
+ this(new SentencePieceEncoder.Builder().addDefaultModel(model));
+ }
+
+ public SentencePieceTester(SentencePieceEncoder.Builder builder) {
+ this(builder.build());
+ }
+
+ public SentencePieceTester(SentencePieceEncoder encoder) {
+ this.encoder = encoder;
+ }
+
+ public void assertEncoded(String input, Integer... expectedCodes) {
+ assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
+ }
+
+ public void assertEncoded(String input, String tensorType, String tensor) {
+ TensorType type = TensorType.fromSpec(tensorType);
+ Tensor expected = Tensor.from(type, tensor);
+ assertEquals(expected, encoder.encode(input, Language.UNKNOWN, type));
+ }
+
+ public void assertSegmented(String input, String... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+
+ public void assertSegmented(Language language, String input, String... expectedSegments) {
+ assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
+ }
+
+}
diff --git a/linguistics-components/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model b/linguistics-components/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model
new file mode 100644
index 00000000000..89f93ef3517
--- /dev/null
+++ b/linguistics-components/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model
Binary files differ
diff --git a/linguistics-components/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model b/linguistics-components/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model
new file mode 100644
index 00000000000..41c0688d9df
--- /dev/null
+++ b/linguistics-components/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model
Binary files differ