diff options
Diffstat (limited to 'linguistics/src/main/java/com/yahoo/language/sentencepiece')
8 files changed, 0 insertions, 490 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java deleted file mode 100644 index 74f300057dc..00000000000 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.sentencepiece; - -import com.yahoo.io.IOUtils; -import com.yahoo.language.Language; -import sentencepiece.SentencepieceModel; - -import java.io.IOException; -import java.nio.file.Path; - -/** - * A SentencePiece model - * - * @author bratseth - */ -final class Model { - - final Path source; - final Language language; - final float minScore; - final float maxScore; - final Trie tokens = new Trie(); - - Model(Language language, Path path) { - try { - this.source = path; - this.language = language; - var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile())); - float minScore = Float.MAX_VALUE; - float maxScore = Float.MIN_VALUE; - for (int i = 0; i < sp.getPiecesCount(); i++) { - var piece = sp.getPieces(i); - tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); - minScore = Math.min(piece.getScore(), minScore); - maxScore = Math.max(piece.getScore(), maxScore); - } - this.minScore = minScore; - this.maxScore = maxScore; - } catch (IOException e) { - throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e); - } - } - - private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) { - switch (type) { - case USER_DEFINED : return TokenType.userDefined; - case UNKNOWN : return TokenType.unknown; - case NORMAL : return TokenType.text; - case CONTROL : return TokenType.control; - case UNUSED : return TokenType.unused; - default : throw new IllegalArgumentException("Unknkown token type " + type); - } - } - - @Override - public String toString() { - return "SentencePiece model for " + language + ": '" + source + "'"; - } - -} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java deleted file mode 100644 index 2141505374c..00000000000 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.sentencepiece; - -/** - * Builds a result from a sentencepiece tokenization by being called for each segment in reverse - * - * @param <RESULTTYPE> the type of result this produces - * @author bratseth - */ -abstract class ResultBuilder<RESULTTYPE> { - - private final RESULTTYPE result; - - ResultBuilder(RESULTTYPE result) { - this.result = result; - } - - /** Called for each segment, starting from the last and working backwards */ - abstract void add(int start, int end, SentencePieceAlgorithm.SegmentEnd[] segmentEnds); - - RESULTTYPE result() {return result;} - - void build(String input, SentencePieceAlgorithm.SegmentEnd[] segmentEnds, boolean collapseUnknowns) { - if (collapseUnknowns) { - int segmentEnd = input.length(); - int collapsedSegmentEnd = segmentEnd; - while (segmentEnd > 0) { - if (segmentEnds[segmentEnd].type != TokenType.unknown ) { - if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment - add(segmentEnd, collapsedSegmentEnd, segmentEnds); - } - add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); - collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart; - } - segmentEnd = segmentEnds[segmentEnd].segmentStart; - } - } - else { - int segmentEnd = input.length(); - while (segmentEnd > 0) { - add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); - segmentEnd = segmentEnds[segmentEnd].segmentStart; - } - } - } - -} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java deleted file mode 100644 index 6c8560abee7..00000000000 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.sentencepiece; - -/** - * The scoring strategy to use for picking segments - * - * @author bratseth - */ -public enum Scoring { - - /** Find the segmentation that has the highest score */ - highestScore, - - /** Find the segmentation that has the fewest segments, resolve ties by score sum */ - fewestSegments - -} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java deleted file mode 100644 index 1659e3c0fa7..00000000000 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.sentencepiece; - -/** - * SentencePiece algorithm implementation - * - * @author bratseth - */ -class SentencePieceAlgorithm { - - // TODO: Support characters beyond BMP - - static final char spaceSymbol = '▁'; - - private final boolean collapseUnknowns; - private final Scoring scoring; - - SentencePieceAlgorithm(boolean collapseUnknowns, Scoring scoring) { - this.collapseUnknowns = collapseUnknowns; - this.scoring = scoring; - } - - public <RESULTTYPE> void segment(String input, ResultBuilder<RESULTTYPE> resultBuilder, Model model) { - SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; - segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0); - int start = 0; - while (start < input.length()) { // segment from this position to the end of the text - Trie.Node node = model.tokens.root; - int characterPosition = start; - while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position - node = node.children.get(input.charAt(characterPosition++)); - int length = characterPosition - start; - if (node != null && node.isToken() && node.type != TokenType.unused) { - float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score; - addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds); - } - else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable - addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds); - } - } - start++; - } - resultBuilder.build(input, segmentEnds, collapseUnknowns); - } - - private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) { - if (segmentEnds[end] == null || - segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) { - segmentEnds[end] = new SegmentEnd(type, id, - segmentEnds[start].pathScoreSum + score, - segmentEnds[start].pathSegmentCount + 1, - start); - } - } - - final class SegmentEnd { - - final TokenType type; - final int id; - final float pathScoreSum; - final int pathSegmentCount; - final int segmentStart; - - SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) { - this.type = type; - this.id = id; - this.pathScoreSum = pathScoreSum; - this.pathSegmentCount = pathSegmentCount; - this.segmentStart = segmentStart; - } - - public float score() { - switch (scoring) { - case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum; - case highestScore: return pathScoreSum; - default : throw new IllegalArgumentException("Unknown scoring " + scoring); - } - } - - public float scoreWith(float additionalSegmentScore) { - switch (scoring) { - case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore ); - case highestScore: return pathScoreSum + additionalSegmentScore; - default : throw new IllegalArgumentException("Unknown scoring " + scoring); - } - } - - } - -} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java deleted file mode 100644 index b6659ebeaa3..00000000000 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.sentencepiece; - -import com.google.common.annotations.Beta; -import com.google.inject.Inject; -import com.yahoo.language.Language; -import com.yahoo.language.process.Encoder; -import com.yahoo.language.process.Segmenter; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; - -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * Integration with https://github.com/google/sentencepiece - * through http://docs.djl.ai/extensions/sentencepiece/index.html - * - * SentencePiece is a language-agnostic tokenizer for neural nets. - * - * @author bratseth - */ -@Beta -public class SentencePieceEncoder implements Segmenter, Encoder { - - private final Map<Language, Model> models; - - private final SentencePieceAlgorithm algorithm; - - @Inject - public SentencePieceEncoder(SentencePieceConfig config) { - this(new Builder(config)); - } - - public SentencePieceEncoder(Builder builder) { - algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring()); - - models = builder.getModels().entrySet() - .stream() - .map(e -> new Model(e.getKey(), e.getValue())) - .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); - if (models.isEmpty()) - throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured"); - } - - /** - * Segments the given text into token segments using the SentencePiece algorithm - * - * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. - * @param language the model to use, or Language.UNKNOWN to use the default model if any - * @return the list of zero or more tokens resulting from segmenting the input text - */ - @Override - public List<String> segment(String rawInput, Language language) { - String input = normalize(rawInput); - var resultBuilder = new ResultBuilder<List<String>>(new ArrayList<>()) { - public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { - result().add(input.substring(segmentStart, segmentEnd)); - } - }; - segment(input, language, resultBuilder); - Collections.reverse(resultBuilder.result()); - return resultBuilder.result(); - } - - /** - * Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids. - * - * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. - * @param language the model to use, or Language.UNKNOWN to use the default model if any - * @return the list of zero or more token ids resulting from segmenting the input text - */ - @Override - public List<Integer> encode(String rawInput, Language language) { - var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) { - public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { - result().add(segmentEnds[segmentEnd].id); - } - }; - segment(normalize(rawInput), language, resultBuilder); - Collections.reverse(resultBuilder.result()); - return resultBuilder.result(); - } - - /** - * <p>Encodes directly to a tensor.</p> - * - * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order - * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small - * it will be truncated.</p> - * - * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token - * position as value.</p> - * - * <p>If the tensor is any other type IllegalArgumentException is thrown.</p> - */ - @Override - public Tensor encode(String rawInput, Language language, TensorType type) { - if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { - // Build to a list first since we can't reverse a tensor builder - List<Integer> values = encode(rawInput, language); - - long maxSize = values.size(); - if (type.dimensions().get(0).size().isPresent()) - maxSize = Math.min(maxSize, type.dimensions().get(0).size().get()); - - Tensor.Builder builder = Tensor.Builder.of(type); - for (int i = 0; i < maxSize; i++) - builder.cell(values.get(i), i); - return builder.build(); - } - else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) { - // Build to a list first since we can't reverse a tensor builder - List<String> values = segment(rawInput, language); - - Tensor.Builder builder = Tensor.Builder.of(type); - for (int i = 0; i < values.size(); i++) - builder.cell(TensorAddress.ofLabels(values.get(i)), i); - return builder.build(); - } - else { - throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type); - } - } - - private <RESULTTYPE> void segment(String input, Language language, - ResultBuilder<RESULTTYPE> resultBuilder) { - Model model = resolveFrom(language); - algorithm.segment(input, resultBuilder, model); - } - - private Model resolveFrom(Language language) { - // Disregard language if there is default model - if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); - if (models.containsKey(language)) return models.get(language); - throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured"); - } - - public String normalize(String s) { - StringBuilder b = new StringBuilder(s.length() + 1); - boolean queuedSpace = true; // Always start by one space - for (int i = 0; i < s.length(); i++) { - char c = s.charAt(i); - if (s.charAt(i) == ' ') { - queuedSpace = true; - } - else { - if (queuedSpace) { - b.append(SentencePieceAlgorithm.spaceSymbol); - queuedSpace = false; - } - b.append(c); - } - } - return b.toString(); - } - - public static class Builder { - - private final Map<Language, Path> models = new HashMap<>(); - private boolean collapseUnknowns = true; - private Scoring scoring = Scoring.fewestSegments; - - public Builder() { - } - - private Builder(SentencePieceConfig config) { - collapseUnknowns = config.collapseUnknowns(); - scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments - : Scoring.highestScore; - for (SentencePieceConfig.Model model : config.model()) { - addModel(Language.fromLanguageTag(model.language()), model.path()); - } - } - - public void addModel(Language language, Path model) { - models.put(language, model); - } - - /** - * Adds the model that will be used if the language is unknown, OR only one model is specified. - * The same as addModel(Language.UNKNOWN, model). - */ - public Builder addDefaultModel(Path model) { - addModel(Language.UNKNOWN, model); - return this; - } - public Map<Language, Path> getModels() { return models; } - - /** - * Sets whether consecutive unknown character should be collapsed into one large unknown token (default) - * or be returned as single character tokens. - */ - public Builder setCollapseUnknowns(boolean collapseUnknowns) { - this.collapseUnknowns = collapseUnknowns; - return this; - } - public boolean getCollapseUnknowns() { return collapseUnknowns; } - - /** Sets the scoring strategy to use when picking a segmentation. Default: fewestSegments. */ - public Builder setScoring(Scoring scoring) { - this.scoring = scoring; - return this; - } - public Scoring getScoring() { return scoring; } - - public SentencePieceEncoder build() { - if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); - return new SentencePieceEncoder(this); - } - - } - -} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java deleted file mode 100644 index 782030a8e4d..00000000000 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.sentencepiece; - -/** - * SentencePiece token types - * - * @author bratseth - */ -enum TokenType { - - text, control, userDefined, unknown, unused - -} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java deleted file mode 100644 index 8e7c2db2ed3..00000000000 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.language.sentencepiece; - -import java.util.HashMap; -import java.util.Map; - -/** - * A simple trie for sentencepiece token lookups. - * - * @author bratseth - */ -class Trie { - - final Node root = new Node(); - - void add(TokenType type, int id, String word, float score) { - Node current = root; - for (char l : word.toCharArray()) - current = current.children.computeIfAbsent(l, c -> new Node()); - current.type = type; - current.id = id; - current.score = score; - } - - static class Node { - - Integer id; - TokenType type; - Float score; - final Map<Character, Node> children = new HashMap<>(); - - boolean isToken() { return type != null; } - - } - -} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/package-info.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/package-info.java deleted file mode 100644 index 4a8673705ec..00000000000 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/package-info.java +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package com.yahoo.language.sentencepiece; - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; |