diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-09-13 19:29:36 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-09-13 19:29:36 +0200 |
commit | bfae6ba4377ea51a7cc024c5e9772ed069feedf4 (patch) | |
tree | ff20b7491674e60a9851ee4bdb576677bb4fc486 /linguistics | |
parent | dd26caaa5cc05e9cacaa280a4bee5d9ddb56ecbc (diff) |
Pure Java sentencepiece implementation
Diffstat (limited to 'linguistics')
-rw-r--r-- | linguistics/pom.xml | 8 | ||||
-rw-r--r-- | linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java | 332 | ||||
-rw-r--r-- | linguistics/src/main/java/com/yahoo/language/simple/SimpleToken.java | 5 | ||||
-rw-r--r-- | linguistics/src/main/protobuf/sentencepiece_model.proto | 310 | ||||
-rw-r--r-- | linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java | 78 | ||||
-rw-r--r-- | linguistics/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model | bin | 0 -> 400869 bytes | |||
-rw-r--r-- | linguistics/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model | bin | 0 -> 300865 bytes |
7 files changed, 731 insertions, 2 deletions
diff --git a/linguistics/pom.xml b/linguistics/pom.xml index 3cc430dacc6..221d7181616 100644 --- a/linguistics/pom.xml +++ b/linguistics/pom.xml @@ -15,6 +15,10 @@ <version>7-SNAPSHOT</version> <dependencies> <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </dependency> + <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <scope>test</scope> @@ -69,6 +73,10 @@ <build> <plugins> <plugin> + <groupId>com.github.os72</groupId> + <artifactId>protoc-jar-maven-plugin</artifactId> + </plugin> + <plugin> <groupId>com.yahoo.vespa</groupId> <artifactId>bundle-plugin</artifactId> <extensions>true</extensions> diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java new file mode 100644 index 00000000000..9509c1d070d --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -0,0 +1,332 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.sentencepiece; + +import com.google.common.annotations.Beta; +import com.yahoo.io.IOUtils; +import com.yahoo.language.Language; +import com.yahoo.language.process.Segmenter; +import sentencepiece.SentencepieceModel; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +/** + * Integration with https://github.com/google/sentencepiece + * through http://docs.djl.ai/extensions/sentencepiece/index.html + * + * SentencePiece is a language-agnostic tokenizer for neural nets. + * + * @author bratseth + */ +@Beta +public class SentencePieceEncoder implements Segmenter { + + // TODO: Support characters beyond BMP + + public enum TokenType { text, control, userDefined, unknown, unused } + + /** The scoring strategy to use for picking segments */ + public enum Scoring { + /** Find the segmentation that has the highest score */ + highestScore, + /** Find the segmentation that has the fewest segments, resolve ties by score sum */ + fewestSegments + } + + private static final char spaceSymbol = '▁'; + + private final boolean collapseUnknowns; + private final Scoring scoring; + + private final Map<Language, Model> models; + + public SentencePieceEncoder(Builder builder) { + collapseUnknowns = builder.getCollapseUnknowns(); + scoring = builder.getScoring(); + + models = builder.getModels().entrySet() + .stream() + .map(e -> new Model(e.getKey(), e.getValue())) + .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); + } + + /** + * Segments the given text into token segments using the SentencePiece algorithm + * + * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. + * @param language the model to use, or Language.UNKNOWN to use the default model if any + * @return the list of zero or more tokens resulting from segmenting the input text + */ + @Override + public List<String> segment(String rawInput, Language language) { + String input = normalize(rawInput); + SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; + return segment(input, language, segmentEnds, (segmentStart, segmentEnd) -> input.substring(segmentStart, segmentEnd)); + } + + /** + * Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids. + * + * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. + * @param language the model to use, or Language.UNKNOWN to use the default model if any + * @return the list of zero or more token ids resulting from segmenting the input text + */ + public List<Integer> encode(String rawInput, Language language) { + String input = normalize(rawInput); + SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; + return segment(input, language, segmentEnds, (segmentStart, segmentEnd) -> segmentEnds[segmentEnd].id); + } + + private <ITEMTYPE> List<ITEMTYPE> segment(String input, Language language, + SegmentEnd[] segmentEnds, + BiFunction<Integer, Integer, ITEMTYPE> resultItemMapper) { + Model model = resolveFrom(language); + float unknownScore = model.minScore - 10.0f; + segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0); + + int start = 0; + while (start < input.length()) { // segment from this position to the end of the text + Trie.Node node = model.tokens.root; + int characterPosition = start; + boolean addedSingleCharacterSegment = false; + while (characterPosition < input.length()) { // traverse the trie one character at the time from this position + node = node.children.get(input.charAt(characterPosition)); + characterPosition++; + if (node == null) break; + int length = characterPosition - start; + if (node.isToken()) { + if (node.type == TokenType.unused) continue; + + float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score; + addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds); + } + if (! addedSingleCharacterSegment && length == 1) + addedSingleCharacterSegment = true; + } + if ( ! addedSingleCharacterSegment) // add an unknown 1 character token to be able to start from the next character + addSegment(TokenType.unknown, 0, start, start + 1, unknownScore, segmentEnds); + start++; + } + + return createResult(input, segmentEnds, resultItemMapper); + } + + private Model resolveFrom(Language language) { + // Disregard language if there is default model + if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); + if (models.containsKey(language)) return models.get(language); + throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured"); + } + + private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) { + if (segmentEnds[end] == null || + segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) { + segmentEnds[end] = new SegmentEnd(type, id, + segmentEnds[start].pathScoreSum + score, + segmentEnds[start].pathSegmentCount + 1, + start); + } + } + + private <ITEMTYPE> List<ITEMTYPE> createResult(String input, SegmentEnd[] segmentEnds, + BiFunction<Integer, Integer, ITEMTYPE> resultItemMapper) { + List<ITEMTYPE> result = new ArrayList<>(); + if (collapseUnknowns) { + int segmentEnd = input.length(); + int collapsedSegmentEnd = segmentEnd; + while (segmentEnd > 0) { + if (segmentEnds[segmentEnd].type != TokenType.unknown ) { + if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment + result.add(resultItemMapper.apply(segmentEnd, collapsedSegmentEnd)); + } + result.add(resultItemMapper.apply(segmentEnds[segmentEnd].segmentStart, segmentEnd)); + collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart; + } + segmentEnd = segmentEnds[segmentEnd].segmentStart; + } + } + else { + int segmentEnd = input.length(); + while (segmentEnd > 0) { + result.add(resultItemMapper.apply(segmentEnds[segmentEnd].segmentStart, segmentEnd)); + segmentEnd = segmentEnds[segmentEnd].segmentStart; + } + } + Collections.reverse(result); + return result; + } + + private final class SegmentEnd { + + final TokenType type; + final int id; + final float pathScoreSum; + final int pathSegmentCount; + final int segmentStart; + + SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) { + this.type = type; + this.id = id; + this.pathScoreSum = pathScoreSum; + this.pathSegmentCount = pathSegmentCount; + this.segmentStart = segmentStart; + } + + public float score() { + switch (scoring) { + case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum; + case highestScore: return pathScoreSum; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + public float scoreWith(float additionalSegmentScore) { + switch (scoring) { + case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore ); + case highestScore: return pathScoreSum + additionalSegmentScore; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + } + + public String normalize(String s) { + StringBuilder b = new StringBuilder(s.length() + 1); + boolean queuedSpace = true; // Always start by one space + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if (s.charAt(i) == ' ') { + queuedSpace = true; + } + else { + if (queuedSpace) { + b.append(spaceSymbol); + queuedSpace = false; + } + b.append(c); + } + } + return b.toString(); + } + + private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) { + switch (type) { + case USER_DEFINED : return TokenType.userDefined; + case UNKNOWN : return TokenType.unknown; + case NORMAL : return TokenType.text; + case CONTROL : return TokenType.control; + case UNUSED : return TokenType.unused; + default : throw new IllegalArgumentException("Unknkown token type " + type); + } + } + + private static class Trie { + + final Node root = new Node(); + + void add(TokenType type, int id, String word, float score) { + Node current = root; + for (char l : word.toCharArray()) + current = current.children.computeIfAbsent(l, c -> new Node()); + current.type = type; + current.id = id; + current.score = score; + } + + static class Node { + + Integer id; + TokenType type; + Float score; + private final Map<Character, Node> children = new HashMap<>(); + + boolean isToken() { return score != null; } + + } + + } + + private static final class Model { + + final Language language; + final float minScore; + final float maxScore; + final Trie tokens = new Trie(); + + Model(Language language, Path path) { + try { + this.language = language; + var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile())); + float minScore = Float.MAX_VALUE; + float maxScore = Float.MIN_VALUE; + for (int i = 0; i < sp.getPiecesCount(); i++) { + var piece = sp.getPieces(i); + tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); + minScore = Math.min(piece.getScore(), minScore); + maxScore = Math.max(piece.getScore(), maxScore); + } + this.minScore = minScore; + this.maxScore = maxScore; + } + catch (IOException e) { + throw new IllegalArgumentException("Could not read a SentencePiece model from '" + path + "'", e); + } + } + + } + + public static class Builder { + + private final Map<Language, Path> models = new HashMap<>(); + private boolean collapseUnknowns = true; + private Scoring scoring = Scoring.fewestSegments; + + public void addModel(Language language, Path model) { + models.put(language, model); + } + + /** + * Adds the model that will be used if the language is unknown, OR only one model is specified. + * The same as addModel(Language.UNKNOWN, model). + */ + public Builder addDefaultModel(Path model) { + addModel(Language.UNKNOWN, model); + return this; + } + public Map<Language, Path> getModels() { return models; } + + /** + * Sets whether consecutive unknown character should be collapsed into one large unknown token (default) + * or be returned as single character tokens. + */ + public Builder setCollapseUnknowns(boolean collapseUnknowns) { + this.collapseUnknowns = collapseUnknowns; + return this; + } + public boolean getCollapseUnknowns() { return collapseUnknowns; } + + /** + * Sets the scoring strategy to use when picking a segmentation. Default: fewestTokens. + */ + public Builder setScoring(Scoring scoring) { + this.scoring = scoring; + return this; + } + public Scoring getScoring() { return scoring; } + + public SentencePieceEncoder build() { + if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); + return new SentencePieceEncoder(this); + } + + } + +} diff --git a/linguistics/src/main/java/com/yahoo/language/simple/SimpleToken.java b/linguistics/src/main/java/com/yahoo/language/simple/SimpleToken.java index 7b63650fa94..e72497ebb82 100644 --- a/linguistics/src/main/java/com/yahoo/language/simple/SimpleToken.java +++ b/linguistics/src/main/java/com/yahoo/language/simple/SimpleToken.java @@ -17,16 +17,17 @@ public class SimpleToken implements Token { private final String orig; private TokenType type = TokenType.UNKNOWN; private TokenScript script = TokenScript.UNKNOWN; - private String tokenString = null; + private String tokenString; private boolean specialToken = false; private long offset = 0; public SimpleToken(String orig) { - this.orig = orig; + this(orig, null); } public SimpleToken(String orig, String tokenString) { this.orig = orig; + this.tokenString = tokenString; } @Override diff --git a/linguistics/src/main/protobuf/sentencepiece_model.proto b/linguistics/src/main/protobuf/sentencepiece_model.proto new file mode 100644 index 00000000000..39626aede53 --- /dev/null +++ b/linguistics/src/main/protobuf/sentencepiece_model.proto @@ -0,0 +1,310 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +syntax = "proto2"; + +// TODO(taku): Needs to use LITE RUNTIME in OSS release. +option optimize_for = LITE_RUNTIME; + +package sentencepiece; + +// TrainerSpec encodes a various parameters for SentencePiece training. +message TrainerSpec { + /////////////////////////////////////////////////////////////////// + // General parameters + // + // Input corpus files. + // Trainer accepts the following two formats: + // A) Monolingual: plain text, one sentence per line. + // B) Bilingual: TSV, source sentence <tab> target sentence + // When bilingual data is passed, shared vocabulary model is built. + // Note that the input file must be raw corpus, not a preprocessed corpus. + // Trainer only loads the first `input_sentence_size` sentences specified + // with this parameter. + repeated string input = 1; + + // Input corpus format: + // "text": one-sentence-per-line text format (default) + // "tsv": sentence <tab> freq + optional string input_format = 7; + + // Output model file prefix. + // <model_prefix>.model and <model_prefix>.vocab are generated. + optional string model_prefix = 2; + + // Model type. only have UNIGRAM now. + enum ModelType { + UNIGRAM = 1; // Unigram language model with dynamic algorithm + BPE = 2; // Byte Pair Encoding + WORD = 3; // Delimitered by whitespace. + CHAR = 4; // tokenizes into character sequence + } + optional ModelType model_type = 3 [default = UNIGRAM]; + + // Vocabulary size. 8k is the default size. + optional int32 vocab_size = 4 [default = 8000]; + + // List of the languages this model can accept. + // Since the model is language-agnostic, this field is used as a reference. + repeated string accept_language = 5; + + // Size of self-test samples, which are encoded in the model file. + optional int32 self_test_sample_size = 6 [default = 0]; + + /////////////////////////////////////////////////////////////////// + // Training parameters. + // + // Uses characters which cover the corpus with the ratio of `chars_coverage`. + // This parameter determines the set of basic Alphabet of sentence piece. + // 1.0 - `chars_coverage` characters are treated as UNK. + // See also required_chars field. + optional float character_coverage = 10 [default = 0.9995]; + + // Maximum size of sentences the trainer loads from `input` parameter. + // Trainer simply loads the `input` files in sequence. + // It is better to shuffle the input corpus randomly. + optional uint64 input_sentence_size = 11 [default = 0]; + optional bool shuffle_input_sentence = 19 [default = true]; + + // Maximum size of sentences to make seed sentence pieces. + // Extended suffix array is constructed to extract frequent + // sub-strings from the corpus. This uses 20N working space, + // where N is the size of corpus. + optional int32 mining_sentence_size = 12 [deprecated = true]; + + // Maximum size of sentences to train sentence pieces. + optional int32 training_sentence_size = 13 [deprecated = true]; + + // The size of seed sentencepieces. + // `seed_sentencepiece_size` must be larger than `vocab_size`. + optional int32 seed_sentencepiece_size = 14 [default = 1000000]; + + // In every EM sub-iterations, keeps top + // `shrinking_factor` * `current sentencepieces size` with respect to + // the loss of the sentence piece. This value should be smaller than 1.0. + optional float shrinking_factor = 15 [default = 0.75]; + + // The maximum sentence length in byte. The sentences with the length + // larger than `max_sentence_length` is simply ignored. + // Longer input tends to bring the following risks: + // * Overflow during EM training (unigram language model only) + // * Performance drop because of O(n log n) cost in BPE. + optional int32 max_sentence_length = 18 [default = 4192]; + + // Number of threads in the training. + optional int32 num_threads = 16 [default = 16]; + + // Number of EM sub iterations. + optional int32 num_sub_iterations = 17 [default = 2]; + + /////////////////////////////////////////////////////////////////// + // SentencePiece parameters which control the shapes of sentence piece. + // + // Maximum length of sentencepiece. + optional int32 max_sentencepiece_length = 20 [default = 16]; + + // Uses Unicode script to split sentence pieces. + // When `split_by_unicode_script` is true, we do not allow sentence piece to + // include multiple Unicode scripts, e.g. "F1" is not a valid piece. + // Exception: CJ characters (Hiragana/Katakana/Han) are all handled + // as one script type, since Japanese word can consist of multiple scripts. + // This exception is always applied regardless of the accept-language + // parameter. + optional bool split_by_unicode_script = 21 [default = true]; + + // When `split_by_number` is true, put a boundary between number and + // non-number transition. If we want to treat "F1" is one token, set this flag + // to be false. + optional bool split_by_number = 23 [default = true]; + + // Use a white space to split sentence pieces. + // When `split_by_whitespace` is false, we may have the piece containing + // a white space in the middle. e.g., "in_the". + optional bool split_by_whitespace = 22 [default = true]; + + // Adds whitespace symbol (_) as a suffix instead of prefix. e.g., _hello => + // hello_. When `treat_whitespace_as_suffix` is true, + // NormalizerSpec::add_dummy_prefix will add the dummy whitespace to the end + // of sentence. + optional bool treat_whitespace_as_suffix = 24 [default = false]; + + // Allows pieces that only contain whitespaces instead of appearing only as + // prefix or suffix of other pieces. + optional bool allow_whitespace_only_pieces = 26 [default = false]; + + // Split all digits (0-9) into separate pieces. + optional bool split_digits = 25 [default = false]; + + /////////////////////////////////////////////////////////////////// + // Vocabulary management + // + // Defines control symbols used as an indicator to + // change the behavior of the decoder. <s> and </s> are pre-defined. + // We can use this field to encode various meta information, + // including language indicator in multilingual model. + // These symbols are not visible to users, but visible to + // the decoder. Note that when the input sentence contains control symbols, + // they are not treated as one token, but segmented into normal pieces. + // Control symbols must be inserted independently from the segmentation. + repeated string control_symbols = 30; + + // Defines user defined symbols. + // These symbols are added with extremely high score + // so they are always treated as one unique symbol in any context. + // Typical usage of user_defined_symbols is placeholder for named entities. + repeated string user_defined_symbols = 31; + + // Defines required characters. Each UTF8 character in this string is included + // in the character set regardless of character_coverage value. Unlike + // user_defined_symbols, these characters have scores based on the frequency + // on input sentences, and the model can form subwords using characters + // in this field. + optional string required_chars = 36; + + // Decomposes unknown pieces into UTF-8 bytes. + optional bool byte_fallback = 35 [default = false]; + + // When creating the vocabulary file, defines whether or not to additionally + // output the score for each piece. + optional bool vocabulary_output_piece_score = 32 [default = true]; + + // `vocab_size` is treated as hard limit. Crash if + // the model can not produce the vocab of size `vocab_size`, + // When `hard_vocab_limit` is false, vocab_size is treated + // as soft limit. Note that when model_type=char, + // always assumes hard_vocab_limit = false. + optional bool hard_vocab_limit = 33 [default = true]; + + // use all symbols for vocab extraction. This flag is valid + // if model type is either CHAR or WORD + optional bool use_all_vocab = 34 [default = false]; + + /////////////////////////////////////////////////////////////////// + // Reserved special meta tokens. + // * -1 is not used. + // * unk_id must not be -1. + // Id must starts with 0 and be contigous. + optional int32 unk_id = 40 [default = 0]; // <unk> + optional int32 bos_id = 41 [default = 1]; // <s> + optional int32 eos_id = 42 [default = 2]; // </s> + optional int32 pad_id = 43 [default = -1]; // <pad> (padding) + optional string unk_piece = 45 [default = "<unk>"]; + optional string bos_piece = 46 [default = "<s>"]; + optional string eos_piece = 47 [default = "</s>"]; + optional string pad_piece = 48 [default = "<pad>"]; + + // Encodes <unk> into U+2047 (DOUBLE QUESTION MARK), + // since this character can be useful both for user and + // developer. We can easily figure out that <unk> is emitted. + optional string unk_surface = 44 [default = " \xE2\x81\x87 "]; + + // Increase bit depth to allow unigram model training on large + // (>10M sentences) corpora. A Side-effect of enabling this flag + // is increased memory usage. + optional bool train_extremely_large_corpus = 49 [default = false]; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// NormalizerSpec encodes a various parameters for string normalizaiton +message NormalizerSpec { + // name of normalization rule. + optional string name = 1; + + // Pre-compiled normalization rule created by + // Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method. + // Usually this field is set by Builder::GetNormalizerSpec() method. + optional bytes precompiled_charsmap = 2; + + // Adds dummy whitespace at the beginning of text in order to + // treat "world" in "world" and "hello world" in the same way. + optional bool add_dummy_prefix = 3 [default = true]; + + // Removes leading, trailing, and duplicate internal whitespace. + optional bool remove_extra_whitespaces = 4 [default = true]; + + // Replaces whitespace with meta symbol. + // This field must be true to train sentence piece model. + optional bool escape_whitespaces = 5 [default = true]; + + // Custom normalization rule file in TSV format. + // https://github.com/google/sentencepiece/blob/master/doc/normalization.md + // This field is only used in SentencePieceTrainer::Train() method, which + // compiles the rule into the binary rule stored in `precompiled_charsmap`. + optional string normalization_rule_tsv = 6; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// Proto to store samples for self-testing. +message SelfTestData { + message Sample { + optional string input = 1; + optional string expected = 2; + } + repeated Sample samples = 1; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// ModelProto stores model parameters. +// SentencePieceProcessor is supposed to be self-contained. +// All settings/parameters which may change the behavior must be encoded +// in ModelProto. +message ModelProto { + message SentencePiece { + enum Type { + NORMAL = 1; // normal symbol + UNKNOWN = 2; // unknown symbol. only <unk> for now. + CONTROL = 3; // control symbols. </s>, <s>, <2ja> etc. + USER_DEFINED = 4; // user defined symbols. + // Typical usage of USER_DEFINED symbol + // is placeholder. + BYTE = 6; // byte symbols. Used when `byte_fallback` is true. + UNUSED = 5; // this piece is not used. + } + optional string piece = 1; // piece must not be empty. + optional float score = 2; + optional Type type = 3 [default = NORMAL]; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; + } + + // Sentence pieces with scores. + repeated SentencePiece pieces = 1; + + // Spec used to generate this model file. + optional TrainerSpec trainer_spec = 2; + + // Spec for text normalization. + optional NormalizerSpec normalizer_spec = 3; + + // Stores sample input and its expected segmentation to verify the model. + optional SelfTestData self_test_data = 4; + + // Spec for text de-normalization. + optional NormalizerSpec denormalizer_spec = 5; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +}
\ No newline at end of file diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java new file mode 100644 index 00000000000..70361f55750 --- /dev/null +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -0,0 +1,78 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.sentencepiece; + +import com.yahoo.language.Language; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; + +import static org.junit.Assert.assertArrayEquals; + +/** + * @author bratseth + */ +public class SentencePieceTest { + + @Test + public void testEnglishTokenization() throws IOException { + var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + tester.assertSegmented("h", "▁h"); + tester.assertSegmented("he", "▁he"); + tester.assertSegmented("hel", "▁hel"); + tester.assertSegmented("hello", "▁hel", "lo"); + tester.assertSegmented("hei", "▁he", "i"); + tester.assertSegmented("hei you", "▁he", "i", "▁you"); + tester.assertSegmented("hei you", "▁he", "i", "▁you"); + tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); + tester.assertSegmented("hello world!", "▁hel", "lo", "▁world", "!"); + tester.assertSegmented("Hello, world!", "▁", "H", "ello", ",", "▁world", "!"); + tester.assertSegmented("HELLO, world!", "▁", "HELLO", ",", "▁world", "!"); + tester.assertSegmented("KHJKJHHKJHHSH", "▁", "KHJKJHHKJHHSH"); + tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); + tester.assertSegmented(" hello ", "▁hel", "lo"); + tester.assertSegmented(")(/&#()/\"\")", "▁)", "(", "/", "&", "#", "(", ")", "/", "\"", "\")"); + tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁)", "(", "/", "&", "#", "(", "sm", "all", ")", "/", "\"", "in", "▁qu", "otes", "\")"); + tester.assertSegmented("x.400AS", "▁x", ".", "4", "00", "AS"); + tester.assertSegmented("A normal sentence. Yes one more.", "▁", "A", "▁normal", "▁sentence", ".", "▁", "Y", "es", "▁one", "▁more", "."); + tester.assertEncoded("hello, world!", 908, 1418, 9934, 501, 9960); + tester.assertEncoded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960); + } + + @Test + public void testJapaneseTokenization() throws IOException { + SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder(); + builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); + builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + var tester = new SentencePieceTester(builder); + tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); + } + + private static class SentencePieceTester { + + private final SentencePieceEncoder encoder; + + public SentencePieceTester(Path model) { + this(new SentencePieceEncoder.Builder().addDefaultModel(model)); + } + + public SentencePieceTester(SentencePieceEncoder.Builder builder) { + encoder = builder.build(); + } + + private void assertEncoded(String input, Integer ... expectedCodes) { + assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray()); + } + + private void assertSegmented(String input, String ... expectedSegments) { + assertSegmented(Language.UNKNOWN, input, expectedSegments); + } + private void assertSegmented(Language language, String input, String ... expectedSegments) { + assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray()); + } + + } + +} diff --git a/linguistics/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model b/linguistics/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model Binary files differnew file mode 100644 index 00000000000..89f93ef3517 --- /dev/null +++ b/linguistics/src/test/models/sentencepiece/en.wiki.bpe.vs10000.model diff --git a/linguistics/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model b/linguistics/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model Binary files differnew file mode 100644 index 00000000000..41c0688d9df --- /dev/null +++ b/linguistics/src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model |