diff options
author | Jon Bratseth <bratseth@oath.com> | 2021-09-27 23:09:03 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-27 23:09:03 +0200 |
commit | 2df97d23d9f25ae60f010a2e9f273cb5b38e049b (patch) | |
tree | d2923a45682e91d80e7011c60cfb301e05acead3 /linguistics-components/src | |
parent | 037f756caf4cfb99bcd988174839d7bc385267b9 (diff) | |
parent | 8f3fb1a105ded07144f6de527266a438e48a1766 (diff) |
Merge pull request #19294 from vespa-engine/bratseth/linguistics-componentsv7.473.17
Bratseth/linguistics components
Diffstat (limited to 'linguistics-components/src')
15 files changed, 1015 insertions, 0 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java new file mode 100644 index 00000000000..74f300057dc --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java @@ -0,0 +1,60 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +import com.yahoo.io.IOUtils; +import com.yahoo.language.Language; +import sentencepiece.SentencepieceModel; + +import java.io.IOException; +import java.nio.file.Path; + +/** + * A SentencePiece model + * + * @author bratseth + */ +final class Model { + + final Path source; + final Language language; + final float minScore; + final float maxScore; + final Trie tokens = new Trie(); + + Model(Language language, Path path) { + try { + this.source = path; + this.language = language; + var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile())); + float minScore = Float.MAX_VALUE; + float maxScore = Float.MIN_VALUE; + for (int i = 0; i < sp.getPiecesCount(); i++) { + var piece = sp.getPieces(i); + tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); + minScore = Math.min(piece.getScore(), minScore); + maxScore = Math.max(piece.getScore(), maxScore); + } + this.minScore = minScore; + this.maxScore = maxScore; + } catch (IOException e) { + throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e); + } + } + + private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) { + switch (type) { + case USER_DEFINED : return TokenType.userDefined; + case UNKNOWN : return TokenType.unknown; + case NORMAL : return TokenType.text; + case CONTROL : return TokenType.control; + case UNUSED : return TokenType.unused; + default : throw new IllegalArgumentException("Unknkown token type " + type); + } + } + + @Override + public String toString() { + return "SentencePiece model for " + language + ": '" + source + "'"; + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java new file mode 100644 index 00000000000..2141505374c --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java @@ -0,0 +1,47 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * Builds a result from a sentencepiece tokenization by being called for each segment in reverse + * + * @param <RESULTTYPE> the type of result this produces + * @author bratseth + */ +abstract class ResultBuilder<RESULTTYPE> { + + private final RESULTTYPE result; + + ResultBuilder(RESULTTYPE result) { + this.result = result; + } + + /** Called for each segment, starting from the last and working backwards */ + abstract void add(int start, int end, SentencePieceAlgorithm.SegmentEnd[] segmentEnds); + + RESULTTYPE result() {return result;} + + void build(String input, SentencePieceAlgorithm.SegmentEnd[] segmentEnds, boolean collapseUnknowns) { + if (collapseUnknowns) { + int segmentEnd = input.length(); + int collapsedSegmentEnd = segmentEnd; + while (segmentEnd > 0) { + if (segmentEnds[segmentEnd].type != TokenType.unknown ) { + if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment + add(segmentEnd, collapsedSegmentEnd, segmentEnds); + } + add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); + collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart; + } + segmentEnd = segmentEnds[segmentEnd].segmentStart; + } + } + else { + int segmentEnd = input.length(); + while (segmentEnd > 0) { + add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); + segmentEnd = segmentEnds[segmentEnd].segmentStart; + } + } + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java new file mode 100644 index 00000000000..6c8560abee7 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java @@ -0,0 +1,17 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * The scoring strategy to use for picking segments + * + * @author bratseth + */ +public enum Scoring { + + /** Find the segmentation that has the highest score */ + highestScore, + + /** Find the segmentation that has the fewest segments, resolve ties by score sum */ + fewestSegments + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java new file mode 100644 index 00000000000..1659e3c0fa7 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java @@ -0,0 +1,90 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * SentencePiece algorithm implementation + * + * @author bratseth + */ +class SentencePieceAlgorithm { + + // TODO: Support characters beyond BMP + + static final char spaceSymbol = '▁'; + + private final boolean collapseUnknowns; + private final Scoring scoring; + + SentencePieceAlgorithm(boolean collapseUnknowns, Scoring scoring) { + this.collapseUnknowns = collapseUnknowns; + this.scoring = scoring; + } + + public <RESULTTYPE> void segment(String input, ResultBuilder<RESULTTYPE> resultBuilder, Model model) { + SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; + segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0); + int start = 0; + while (start < input.length()) { // segment from this position to the end of the text + Trie.Node node = model.tokens.root; + int characterPosition = start; + while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position + node = node.children.get(input.charAt(characterPosition++)); + int length = characterPosition - start; + if (node != null && node.isToken() && node.type != TokenType.unused) { + float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score; + addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds); + } + else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable + addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds); + } + } + start++; + } + resultBuilder.build(input, segmentEnds, collapseUnknowns); + } + + private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) { + if (segmentEnds[end] == null || + segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) { + segmentEnds[end] = new SegmentEnd(type, id, + segmentEnds[start].pathScoreSum + score, + segmentEnds[start].pathSegmentCount + 1, + start); + } + } + + final class SegmentEnd { + + final TokenType type; + final int id; + final float pathScoreSum; + final int pathSegmentCount; + final int segmentStart; + + SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) { + this.type = type; + this.id = id; + this.pathScoreSum = pathScoreSum; + this.pathSegmentCount = pathSegmentCount; + this.segmentStart = segmentStart; + } + + public float score() { + switch (scoring) { + case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum; + case highestScore: return pathScoreSum; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + public float scoreWith(float additionalSegmentScore) { + switch (scoring) { + case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore ); + case highestScore: return pathScoreSum + additionalSegmentScore; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java new file mode 100644 index 00000000000..b6659ebeaa3 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -0,0 +1,220 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +import com.google.common.annotations.Beta; +import com.google.inject.Inject; +import com.yahoo.language.Language; +import com.yahoo.language.process.Encoder; +import com.yahoo.language.process.Segmenter; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Integration with https://github.com/google/sentencepiece + * through http://docs.djl.ai/extensions/sentencepiece/index.html + * + * SentencePiece is a language-agnostic tokenizer for neural nets. + * + * @author bratseth + */ +@Beta +public class SentencePieceEncoder implements Segmenter, Encoder { + + private final Map<Language, Model> models; + + private final SentencePieceAlgorithm algorithm; + + @Inject + public SentencePieceEncoder(SentencePieceConfig config) { + this(new Builder(config)); + } + + public SentencePieceEncoder(Builder builder) { + algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring()); + + models = builder.getModels().entrySet() + .stream() + .map(e -> new Model(e.getKey(), e.getValue())) + .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); + if (models.isEmpty()) + throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured"); + } + + /** + * Segments the given text into token segments using the SentencePiece algorithm + * + * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. + * @param language the model to use, or Language.UNKNOWN to use the default model if any + * @return the list of zero or more tokens resulting from segmenting the input text + */ + @Override + public List<String> segment(String rawInput, Language language) { + String input = normalize(rawInput); + var resultBuilder = new ResultBuilder<List<String>>(new ArrayList<>()) { + public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { + result().add(input.substring(segmentStart, segmentEnd)); + } + }; + segment(input, language, resultBuilder); + Collections.reverse(resultBuilder.result()); + return resultBuilder.result(); + } + + /** + * Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids. + * + * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. + * @param language the model to use, or Language.UNKNOWN to use the default model if any + * @return the list of zero or more token ids resulting from segmenting the input text + */ + @Override + public List<Integer> encode(String rawInput, Language language) { + var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) { + public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { + result().add(segmentEnds[segmentEnd].id); + } + }; + segment(normalize(rawInput), language, resultBuilder); + Collections.reverse(resultBuilder.result()); + return resultBuilder.result(); + } + + /** + * <p>Encodes directly to a tensor.</p> + * + * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order + * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small + * it will be truncated.</p> + * + * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token + * position as value.</p> + * + * <p>If the tensor is any other type IllegalArgumentException is thrown.</p> + */ + @Override + public Tensor encode(String rawInput, Language language, TensorType type) { + if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { + // Build to a list first since we can't reverse a tensor builder + List<Integer> values = encode(rawInput, language); + + long maxSize = values.size(); + if (type.dimensions().get(0).size().isPresent()) + maxSize = Math.min(maxSize, type.dimensions().get(0).size().get()); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int i = 0; i < maxSize; i++) + builder.cell(values.get(i), i); + return builder.build(); + } + else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) { + // Build to a list first since we can't reverse a tensor builder + List<String> values = segment(rawInput, language); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int i = 0; i < values.size(); i++) + builder.cell(TensorAddress.ofLabels(values.get(i)), i); + return builder.build(); + } + else { + throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type); + } + } + + private <RESULTTYPE> void segment(String input, Language language, + ResultBuilder<RESULTTYPE> resultBuilder) { + Model model = resolveFrom(language); + algorithm.segment(input, resultBuilder, model); + } + + private Model resolveFrom(Language language) { + // Disregard language if there is default model + if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); + if (models.containsKey(language)) return models.get(language); + throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured"); + } + + public String normalize(String s) { + StringBuilder b = new StringBuilder(s.length() + 1); + boolean queuedSpace = true; // Always start by one space + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if (s.charAt(i) == ' ') { + queuedSpace = true; + } + else { + if (queuedSpace) { + b.append(SentencePieceAlgorithm.spaceSymbol); + queuedSpace = false; + } + b.append(c); + } + } + return b.toString(); + } + + public static class Builder { + + private final Map<Language, Path> models = new HashMap<>(); + private boolean collapseUnknowns = true; + private Scoring scoring = Scoring.fewestSegments; + + public Builder() { + } + + private Builder(SentencePieceConfig config) { + collapseUnknowns = config.collapseUnknowns(); + scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments + : Scoring.highestScore; + for (SentencePieceConfig.Model model : config.model()) { + addModel(Language.fromLanguageTag(model.language()), model.path()); + } + } + + public void addModel(Language language, Path model) { + models.put(language, model); + } + + /** + * Adds the model that will be used if the language is unknown, OR only one model is specified. + * The same as addModel(Language.UNKNOWN, model). + */ + public Builder addDefaultModel(Path model) { + addModel(Language.UNKNOWN, model); + return this; + } + public Map<Language, Path> getModels() { return models; } + + /** + * Sets whether consecutive unknown character should be collapsed into one large unknown token (default) + * or be returned as single character tokens. + */ + public Builder setCollapseUnknowns(boolean collapseUnknowns) { + this.collapseUnknowns = collapseUnknowns; + return this; + } + public boolean getCollapseUnknowns() { return collapseUnknowns; } + + /** Sets the scoring strategy to use when picking a segmentation. Default: fewestSegments. */ + public Builder setScoring(Scoring scoring) { + this.scoring = scoring; + return this; + } + public Scoring getScoring() { return scoring; } + + public SentencePieceEncoder build() { + if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); + return new SentencePieceEncoder(this); + } + + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java new file mode 100644 index 00000000000..782030a8e4d --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java @@ -0,0 +1,13 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * SentencePiece token types + * + * @author bratseth + */ +enum TokenType { + + text, control, userDefined, unknown, unused + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java new file mode 100644 index 00000000000..8e7c2db2ed3 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java @@ -0,0 +1,36 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +import java.util.HashMap; +import java.util.Map; + +/** + * A simple trie for sentencepiece token lookups. + * + * @author bratseth + */ +class Trie { + + final Node root = new Node(); + + void add(TokenType type, int id, String word, float score) { + Node current = root; + for (char l : word.toCharArray()) + current = current.children.computeIfAbsent(l, c -> new Node()); + current.type = type; + current.id = id; + current.score = score; + } + + static class Node { + + Integer id; + TokenType type; + Float score; + final Map<Character, Node> children = new HashMap<>(); + + boolean isToken() { return type != null; } + + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java new file mode 100644 index 00000000000..3f97277c489 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java @@ -0,0 +1,7 @@ +// Copyright 2021 Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package com.yahoo.language.sentencepiece; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/linguistics-components/src/main/protobuf/sentencepiece_model.proto b/linguistics-components/src/main/protobuf/sentencepiece_model.proto new file mode 100644 index 00000000000..39626aede53 --- /dev/null +++ b/linguistics-components/src/main/protobuf/sentencepiece_model.proto @@ -0,0 +1,310 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +syntax = "proto2"; + +// TODO(taku): Needs to use LITE RUNTIME in OSS release. +option optimize_for = LITE_RUNTIME; + +package sentencepiece; + +// TrainerSpec encodes a various parameters for SentencePiece training. +message TrainerSpec { + /////////////////////////////////////////////////////////////////// + // General parameters + // + // Input corpus files. + // Trainer accepts the following two formats: + // A) Monolingual: plain text, one sentence per line. + // B) Bilingual: TSV, source sentence <tab> target sentence + // When bilingual data is passed, shared vocabulary model is built. + // Note that the input file must be raw corpus, not a preprocessed corpus. + // Trainer only loads the first `input_sentence_size` sentences specified + // with this parameter. + repeated string input = 1; + + // Input corpus format: + // "text": one-sentence-per-line text format (default) + // "tsv": sentence <tab> freq + optional string input_format = 7; + + // Output model file prefix. + // <model_prefix>.model and <model_prefix>.vocab are generated. + optional string model_prefix = 2; + + // Model type. only have UNIGRAM now. + enum ModelType { + UNIGRAM = 1; // Unigram language model with dynamic algorithm + BPE = 2; // Byte Pair Encoding + WORD = 3; // Delimitered by whitespace. + CHAR = 4; // tokenizes into character sequence + } + optional ModelType model_type = 3 [default = UNIGRAM]; + + // Vocabulary size. 8k is the default size. + optional int32 vocab_size = 4 [default = 8000]; + + // List of the languages this model can accept. + // Since the model is language-agnostic, this field is used as a reference. + repeated string accept_language = 5; + + // Size of self-test samples, which are encoded in the model file. + optional int32 self_test_sample_size = 6 [default = 0]; + + /////////////////////////////////////////////////////////////////// + // Training parameters. + // + // Uses characters which cover the corpus with the ratio of `chars_coverage`. + // This parameter determines the set of basic Alphabet of sentence piece. + // 1.0 - `chars_coverage` characters are treated as UNK. + // See also required_chars field. + optional float character_coverage = 10 [default = 0.9995]; + + // Maximum size of sentences the trainer loads from `input` parameter. + // Trainer simply loads the `input` files in sequence. + // It is better to shuffle the input corpus randomly. + optional uint64 input_sentence_size = 11 [default = 0]; + optional bool shuffle_input_sentence = 19 [default = true]; + + // Maximum size of sentences to make seed sentence pieces. + // Extended suffix array is constructed to extract frequent + // sub-strings from the corpus. This uses 20N working space, + // where N is the size of corpus. + optional int32 mining_sentence_size = 12 [deprecated = true]; + + // Maximum size of sentences to train sentence pieces. + optional int32 training_sentence_size = 13 [deprecated = true]; + + // The size of seed sentencepieces. + // `seed_sentencepiece_size` must be larger than `vocab_size`. + optional int32 seed_sentencepiece_size = 14 [default = 1000000]; + + // In every EM sub-iterations, keeps top + // `shrinking_factor` * `current sentencepieces size` with respect to + // the loss of the sentence piece. This value should be smaller than 1.0. + optional float shrinking_factor = 15 [default = 0.75]; + + // The maximum sentence length in byte. The sentences with the length + // larger than `max_sentence_length` is simply ignored. + // Longer input tends to bring the following risks: + // * Overflow during EM training (unigram language model only) + // * Performance drop because of O(n log n) cost in BPE. + optional int32 max_sentence_length = 18 [default = 4192]; + + // Number of threads in the training. + optional int32 num_threads = 16 [default = 16]; + + // Number of EM sub iterations. + optional int32 num_sub_iterations = 17 [default = 2]; + + /////////////////////////////////////////////////////////////////// + // SentencePiece parameters which control the shapes of sentence piece. + // + // Maximum length of sentencepiece. + optional int32 max_sentencepiece_length = 20 [default = 16]; + + // Uses Unicode script to split sentence pieces. + // When `split_by_unicode_script` is true, we do not allow sentence piece to + // include multiple Unicode scripts, e.g. "F1" is not a valid piece. + // Exception: CJ characters (Hiragana/Katakana/Han) are all handled + // as one script type, since Japanese word can consist of multiple scripts. + // This exception is always applied regardless of the accept-language + // parameter. + optional bool split_by_unicode_script = 21 [default = true]; + + // When `split_by_number` is true, put a boundary between number and + // non-number transition. If we want to treat "F1" is one token, set this flag + // to be false. + optional bool split_by_number = 23 [default = true]; + + // Use a white space to split sentence pieces. + // When `split_by_whitespace` is false, we may have the piece containing + // a white space in the middle. e.g., "in_the". + optional bool split_by_whitespace = 22 [default = true]; + + // Adds whitespace symbol (_) as a suffix instead of prefix. e.g., _hello => + // hello_. When `treat_whitespace_as_suffix` is true, + // NormalizerSpec::add_dummy_prefix will add the dummy whitespace to the end + // of sentence. + optional bool treat_whitespace_as_suffix = 24 [default = false]; + + // Allows pieces that only contain whitespaces instead of appearing only as + // prefix or suffix of other pieces. + optional bool allow_whitespace_only_pieces = 26 [default = false]; + + // Split all digits (0-9) into separate pieces. + optional bool split_digits = 25 [default = false]; + + /////////////////////////////////////////////////////////////////// + // Vocabulary management + // + // Defines control symbols used as an indicator to + // change the behavior of the decoder. <s> and </s> are pre-defined. + // We can use this field to encode various meta information, + // including language indicator in multilingual model. + // These symbols are not visible to users, but visible to + // the decoder. Note that when the input sentence contains control symbols, + // they are not treated as one token, but segmented into normal pieces. + // Control symbols must be inserted independently from the segmentation. + repeated string control_symbols = 30; + + // Defines user defined symbols. + // These symbols are added with extremely high score + // so they are always treated as one unique symbol in any context. + // Typical usage of user_defined_symbols is placeholder for named entities. + repeated string user_defined_symbols = 31; + + // Defines required characters. Each UTF8 character in this string is included + // in the character set regardless of character_coverage value. Unlike + // user_defined_symbols, these characters have scores based on the frequency + // on input sentences, and the model can form subwords using characters + // in this field. + optional string required_chars = 36; + + // Decomposes unknown pieces into UTF-8 bytes. + optional bool byte_fallback = 35 [default = false]; + + // When creating the vocabulary file, defines whether or not to additionally + // output the score for each piece. + optional bool vocabulary_output_piece_score = 32 [default = true]; + + // `vocab_size` is treated as hard limit. Crash if + // the model can not produce the vocab of size `vocab_size`, + // When `hard_vocab_limit` is false, vocab_size is treated + // as soft limit. Note that when model_type=char, + // always assumes hard_vocab_limit = false. + optional bool hard_vocab_limit = 33 [default = true]; + + // use all symbols for vocab extraction. This flag is valid + // if model type is either CHAR or WORD + optional bool use_all_vocab = 34 [default = false]; + + /////////////////////////////////////////////////////////////////// + // Reserved special meta tokens. + // * -1 is not used. + // * unk_id must not be -1. + // Id must starts with 0 and be contigous. + optional int32 unk_id = 40 [default = 0]; // <unk> + optional int32 bos_id = 41 [default = 1]; // <s> + optional int32 eos_id = 42 [default = 2]; // </s> + optional int32 pad_id = 43 [default = -1]; // <pad> (padding) + optional string unk_piece = 45 [default = "<unk>"]; + optional string bos_piece = 46 [default = "<s>"]; + optional string eos_piece = 47 [default = "</s>"]; + optional string pad_piece = 48 [default = "<pad>"]; + + // Encodes <unk> into U+2047 (DOUBLE QUESTION MARK), + // since this character can be useful both for user and + // developer. We can easily figure out that <unk> is emitted. + optional string unk_surface = 44 [default = " \xE2\x81\x87 "]; + + // Increase bit depth to allow unigram model training on large + // (>10M sentences) corpora. A Side-effect of enabling this flag + // is increased memory usage. + optional bool train_extremely_large_corpus = 49 [default = false]; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// NormalizerSpec encodes a various parameters for string normalizaiton +message NormalizerSpec { + // name of normalization rule. + optional string name = 1; + + // Pre-compiled normalization rule created by + // Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method. + // Usually this field is set by Builder::GetNormalizerSpec() method. + optional bytes precompiled_charsmap = 2; + + // Adds dummy whitespace at the beginning of text in order to + // treat "world" in "world" and "hello world" in the same way. + optional bool add_dummy_prefix = 3 [default = true]; + + // Removes leading, trailing, and duplicate internal whitespace. + optional bool remove_extra_whitespaces = 4 [default = true]; + + // Replaces whitespace with meta symbol. + // This field must be true to train sentence piece model. + optional bool escape_whitespaces = 5 [default = true]; + + // Custom normalization rule file in TSV format. + // https://github.com/google/sentencepiece/blob/master/doc/normalization.md + // This field is only used in SentencePieceTrainer::Train() method, which + // compiles the rule into the binary rule stored in `precompiled_charsmap`. + optional string normalization_rule_tsv = 6; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// Proto to store samples for self-testing. +message SelfTestData { + message Sample { + optional string input = 1; + optional string expected = 2; + } + repeated Sample samples = 1; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// ModelProto stores model parameters. +// SentencePieceProcessor is supposed to be self-contained. +// All settings/parameters which may change the behavior must be encoded +// in ModelProto. +message ModelProto { + message SentencePiece { + enum Type { + NORMAL = 1; // normal symbol + UNKNOWN = 2; // unknown symbol. only <unk> for now. + CONTROL = 3; // control symbols. </s>, <s>, <2ja> etc. + USER_DEFINED = 4; // user defined symbols. + // Typical usage of USER_DEFINED symbol + // is placeholder. + BYTE = 6; // byte symbols. Used when `byte_fallback` is true. + UNUSED = 5; // this piece is not used. + } + optional string piece = 1; // piece must not be empty. + optional float score = 2; + optional Type type = 3 [default = NORMAL]; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; + } + + // Sentence pieces with scores. + repeated SentencePiece pieces = 1; + + // Spec used to generate this model file. + optional TrainerSpec trainer_spec = 2; + + // Spec for text normalization. + optional NormalizerSpec normalizer_spec = 3; + + // Stores sample input and its expected segmentation to verify the model. + optional SelfTestData self_test_data = 4; + + // Spec for text de-normalization. + optional NormalizerSpec denormalizer_spec = 5; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +}
\ No newline at end of file diff --git a/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def b/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def new file mode 100644 index 00000000000..b91c0c45dc4 --- /dev/null +++ b/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def @@ -0,0 +1,18 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +# Configures com.yahoo.language.sentencepiece.SentencePieceEncoder + +namespace=language.sentencepiece + +# Whether consecutive unknown character should be collapsed into one large unknown token (default +# or be returned as single character tokens. +collapseUnknowns bool default=true + +# The scoring strategy to use when picking a segmentation. +scoring enum { highestScore, fewestSegments } default=fewestSegments + +# The language a model is for, one of the language tags in com.yahoo.language.Language. +# Use "unknown" for models to be used with any language. +model[].language string +# The path to the model relative to the application package root +model[].path path
\ No newline at end of file diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java new file mode 100644 index 00000000000..edbbe21ec53 --- /dev/null +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java @@ -0,0 +1,59 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.sentencepiece; + +import com.yahoo.config.FileReference; +import com.yahoo.language.Language; +import org.junit.Test; + +/** + * @author bratseth + */ +public class SentencePieceConfigurationTest { + + @Test + public void testEnglishTokenization() { + var b = new SentencePieceConfig.Builder(); + addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); + tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); + } + + @Test + public void testNoCollapse() { + var b = new SentencePieceConfig.Builder(); + addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + b.collapseUnknowns(false); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); + } + + @Test + public void testHighestScore() { + var b = new SentencePieceConfig.Builder(); + addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + b.scoring(SentencePieceConfig.Scoring.highestScore); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented("hello", "▁h", "el", "lo"); + } + + @Test + public void testMultiLanguageTokenization() { + var b = new SentencePieceConfig.Builder(); + addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b); + addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); + var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); + tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); + tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); + } + + private void addModel(String language, String file, SentencePieceConfig.Builder b) { + var mb = new SentencePieceConfig.Model.Builder(); + mb.language(language); + mb.path(new FileReference(file)); + b.model(mb); + } + +} diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java new file mode 100644 index 00000000000..d60d7386d4b --- /dev/null +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -0,0 +1,89 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.sentencepiece; + +import com.yahoo.language.Language; +import org.junit.Test; + +import java.io.File; + +/** + * @author bratseth + */ +public class SentencePieceTest { + + @Test + public void testEnglishTokenization() { + var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + tester.assertSegmented("h", "▁h"); + tester.assertSegmented("he", "▁he"); + tester.assertSegmented("hel", "▁hel"); + tester.assertSegmented("hello", "▁hel", "lo"); + tester.assertSegmented("hei", "▁he", "i"); + tester.assertSegmented("hei you", "▁he", "i", "▁you"); + tester.assertSegmented("hei you", "▁he", "i", "▁you"); + tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); + tester.assertSegmented("hello world!", "▁hel", "lo", "▁world", "!"); + tester.assertSegmented("Hello, world!", "▁", "H", "ello", ",", "▁world", "!"); + tester.assertSegmented("HELLO, world!", "▁", "HELLO", ",", "▁world", "!"); + tester.assertSegmented("KHJKJHHKJHHSH", "▁", "KHJKJHHKJHHSH"); + tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); + tester.assertSegmented(" hello ", "▁hel", "lo"); + tester.assertSegmented(")(/&#()/\"\")", "▁)", "(", "/", "&", "#", "(", ")", "/", "\"", "\")"); + tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁)", "(", "/", "&", "#", "(", "sm", "all", ")", "/", "\"", "in", "▁qu", "otes", "\")"); + tester.assertSegmented("x.400AS", "▁x", ".", "4", "00", "AS"); + tester.assertSegmented("A normal sentence. Yes one more.", "▁", "A", "▁normal", "▁sentence", ".", "▁", "Y", "es", "▁one", "▁more", "."); + } + + @Test + public void testIntegerListEncoding() { + var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + tester.assertEncoded("hello, world!", 908, 1418, 9934, 501, 9960); + tester.assertEncoded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960); + } + + @Test + public void testDenseTensorEncoding() { + var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + tester.assertEncoded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]"); + tester.assertEncoded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]"); + tester.assertEncoded("hello, world!", "tensor(d[2])", "[908,1418]"); + } + + @Test + public void testSparseTensorEncoding() { + var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + tester.assertEncoded("hello", "tensor(token{})", "{lo:1.0,'▁hel':0.0}"); + } + + @Test + public void testNoCollapse() { + var tester = new SentencePieceTester(new SentencePieceEncoder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setCollapseUnknowns(false)); + tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); + } + + @Test + public void testHighestScore() { + var tester = new SentencePieceTester(new SentencePieceEncoder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setScoring(Scoring.highestScore)); + tester.assertSegmented("h", "▁h"); + tester.assertSegmented("he", "▁he"); + tester.assertSegmented("hel", "▁h", "el"); + tester.assertSegmented("hello", "▁h", "el", "lo"); + } + + @Test + public void testMultiLanguageTokenization() { + SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder(); + builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); + builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + var tester = new SentencePieceTester(builder); + tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); + tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); + tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); + } + +} diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java new file mode 100644 index 00000000000..1ba7c9b472d --- /dev/null +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java @@ -0,0 +1,49 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// + +package com.yahoo.language.sentencepiece; + +import com.yahoo.language.Language; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.nio.file.Path; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +class SentencePieceTester { + + private final SentencePieceEncoder encoder; + + public SentencePieceTester(Path model) { + this(new SentencePieceEncoder.Builder().addDefaultModel(model)); + } + + public SentencePieceTester(SentencePieceEncoder.Builder builder) { + this(builder.build()); + } + + public SentencePieceTester(SentencePieceEncoder encoder) { + this.encoder = encoder; + } + + public void assertEncoded(String input, Integer... expectedCodes) { + assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray()); + } + + public void assertEncoded(String input, String tensorType, String tensor) { + TensorType type = TensorType.fromSpec(tensorType); + Tensor expected = Tensor.from(type, tensor); + assertEquals(expected, encoder.encode(input, Language.UNKNOWN, type)); + } + + public void assertSegmented(String input, String... expectedSegments) { + assertSegmented(Language.UNKNOWN, input, expectedSegments); + } + + public void assertSegmented(Language language, String input, String... expectedSegments) { + assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray()); + } + +} diff --git a/linguistics-components/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model b/linguistics-components/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model Binary files differnew file mode 100644 index 00000000000..89f93ef3517 --- /dev/null +++ b/linguistics-components/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model diff --git a/linguistics-components/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model b/linguistics-components/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model Binary files differnew file mode 100644 index 00000000000..41c0688d9df --- /dev/null +++ b/linguistics-components/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model |