From 1abb5adacbdbcfad7070243630164e4d31f68773 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Sat, 25 Sep 2021 14:50:33 +0000 Subject: Separate component from linguistics --- .../com/yahoo/language/sentencepiece/Model.java | 60 ++++++ .../language/sentencepiece/ResultBuilder.java | 47 +++++ .../com/yahoo/language/sentencepiece/Scoring.java | 17 ++ .../sentencepiece/SentencePieceAlgorithm.java | 90 +++++++++ .../sentencepiece/SentencePieceEncoder.java | 220 +++++++++++++++++++++ .../yahoo/language/sentencepiece/TokenType.java | 13 ++ .../com/yahoo/language/sentencepiece/Trie.java | 36 ++++ .../yahoo/language/sentencepiece/package-info.java | 7 + 8 files changed, 490 insertions(+) create mode 100644 linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java create mode 100644 linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java create mode 100644 linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java create mode 100644 linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java create mode 100644 linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java create mode 100644 linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java create mode 100644 linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java create mode 100644 linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java (limited to 'linguistics-components/src/main/java/com/yahoo') diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java new file mode 100644 index 00000000000..74f300057dc --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java @@ -0,0 +1,60 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +import com.yahoo.io.IOUtils; +import com.yahoo.language.Language; +import sentencepiece.SentencepieceModel; + +import java.io.IOException; +import java.nio.file.Path; + +/** + * A SentencePiece model + * + * @author bratseth + */ +final class Model { + + final Path source; + final Language language; + final float minScore; + final float maxScore; + final Trie tokens = new Trie(); + + Model(Language language, Path path) { + try { + this.source = path; + this.language = language; + var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile())); + float minScore = Float.MAX_VALUE; + float maxScore = Float.MIN_VALUE; + for (int i = 0; i < sp.getPiecesCount(); i++) { + var piece = sp.getPieces(i); + tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); + minScore = Math.min(piece.getScore(), minScore); + maxScore = Math.max(piece.getScore(), maxScore); + } + this.minScore = minScore; + this.maxScore = maxScore; + } catch (IOException e) { + throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e); + } + } + + private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) { + switch (type) { + case USER_DEFINED : return TokenType.userDefined; + case UNKNOWN : return TokenType.unknown; + case NORMAL : return TokenType.text; + case CONTROL : return TokenType.control; + case UNUSED : return TokenType.unused; + default : throw new IllegalArgumentException("Unknkown token type " + type); + } + } + + @Override + public String toString() { + return "SentencePiece model for " + language + ": '" + source + "'"; + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java new file mode 100644 index 00000000000..2141505374c --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java @@ -0,0 +1,47 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * Builds a result from a sentencepiece tokenization by being called for each segment in reverse + * + * @param the type of result this produces + * @author bratseth + */ +abstract class ResultBuilder { + + private final RESULTTYPE result; + + ResultBuilder(RESULTTYPE result) { + this.result = result; + } + + /** Called for each segment, starting from the last and working backwards */ + abstract void add(int start, int end, SentencePieceAlgorithm.SegmentEnd[] segmentEnds); + + RESULTTYPE result() {return result;} + + void build(String input, SentencePieceAlgorithm.SegmentEnd[] segmentEnds, boolean collapseUnknowns) { + if (collapseUnknowns) { + int segmentEnd = input.length(); + int collapsedSegmentEnd = segmentEnd; + while (segmentEnd > 0) { + if (segmentEnds[segmentEnd].type != TokenType.unknown ) { + if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment + add(segmentEnd, collapsedSegmentEnd, segmentEnds); + } + add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); + collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart; + } + segmentEnd = segmentEnds[segmentEnd].segmentStart; + } + } + else { + int segmentEnd = input.length(); + while (segmentEnd > 0) { + add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); + segmentEnd = segmentEnds[segmentEnd].segmentStart; + } + } + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java new file mode 100644 index 00000000000..6c8560abee7 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Scoring.java @@ -0,0 +1,17 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * The scoring strategy to use for picking segments + * + * @author bratseth + */ +public enum Scoring { + + /** Find the segmentation that has the highest score */ + highestScore, + + /** Find the segmentation that has the fewest segments, resolve ties by score sum */ + fewestSegments + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java new file mode 100644 index 00000000000..1659e3c0fa7 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java @@ -0,0 +1,90 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * SentencePiece algorithm implementation + * + * @author bratseth + */ +class SentencePieceAlgorithm { + + // TODO: Support characters beyond BMP + + static final char spaceSymbol = '▁'; + + private final boolean collapseUnknowns; + private final Scoring scoring; + + SentencePieceAlgorithm(boolean collapseUnknowns, Scoring scoring) { + this.collapseUnknowns = collapseUnknowns; + this.scoring = scoring; + } + + public void segment(String input, ResultBuilder resultBuilder, Model model) { + SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; + segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0); + int start = 0; + while (start < input.length()) { // segment from this position to the end of the text + Trie.Node node = model.tokens.root; + int characterPosition = start; + while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position + node = node.children.get(input.charAt(characterPosition++)); + int length = characterPosition - start; + if (node != null && node.isToken() && node.type != TokenType.unused) { + float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score; + addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds); + } + else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable + addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds); + } + } + start++; + } + resultBuilder.build(input, segmentEnds, collapseUnknowns); + } + + private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) { + if (segmentEnds[end] == null || + segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) { + segmentEnds[end] = new SegmentEnd(type, id, + segmentEnds[start].pathScoreSum + score, + segmentEnds[start].pathSegmentCount + 1, + start); + } + } + + final class SegmentEnd { + + final TokenType type; + final int id; + final float pathScoreSum; + final int pathSegmentCount; + final int segmentStart; + + SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) { + this.type = type; + this.id = id; + this.pathScoreSum = pathScoreSum; + this.pathSegmentCount = pathSegmentCount; + this.segmentStart = segmentStart; + } + + public float score() { + switch (scoring) { + case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum; + case highestScore: return pathScoreSum; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + public float scoreWith(float additionalSegmentScore) { + switch (scoring) { + case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore ); + case highestScore: return pathScoreSum + additionalSegmentScore; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java new file mode 100644 index 00000000000..b6659ebeaa3 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -0,0 +1,220 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +import com.google.common.annotations.Beta; +import com.google.inject.Inject; +import com.yahoo.language.Language; +import com.yahoo.language.process.Encoder; +import com.yahoo.language.process.Segmenter; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Integration with https://github.com/google/sentencepiece + * through http://docs.djl.ai/extensions/sentencepiece/index.html + * + * SentencePiece is a language-agnostic tokenizer for neural nets. + * + * @author bratseth + */ +@Beta +public class SentencePieceEncoder implements Segmenter, Encoder { + + private final Map models; + + private final SentencePieceAlgorithm algorithm; + + @Inject + public SentencePieceEncoder(SentencePieceConfig config) { + this(new Builder(config)); + } + + public SentencePieceEncoder(Builder builder) { + algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring()); + + models = builder.getModels().entrySet() + .stream() + .map(e -> new Model(e.getKey(), e.getValue())) + .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); + if (models.isEmpty()) + throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured"); + } + + /** + * Segments the given text into token segments using the SentencePiece algorithm + * + * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. + * @param language the model to use, or Language.UNKNOWN to use the default model if any + * @return the list of zero or more tokens resulting from segmenting the input text + */ + @Override + public List segment(String rawInput, Language language) { + String input = normalize(rawInput); + var resultBuilder = new ResultBuilder>(new ArrayList<>()) { + public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { + result().add(input.substring(segmentStart, segmentEnd)); + } + }; + segment(input, language, resultBuilder); + Collections.reverse(resultBuilder.result()); + return resultBuilder.result(); + } + + /** + * Segments the given text into token segments using the SentencePiece algorithm and returns the segment ids. + * + * @param rawInput the text to segment. Any sequence of BMP (Unicode-16 the True Unicode) is supported. + * @param language the model to use, or Language.UNKNOWN to use the default model if any + * @return the list of zero or more token ids resulting from segmenting the input text + */ + @Override + public List encode(String rawInput, Language language) { + var resultBuilder = new ResultBuilder>(new ArrayList<>()) { + public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { + result().add(segmentEnds[segmentEnd].id); + } + }; + segment(normalize(rawInput), language, resultBuilder); + Collections.reverse(resultBuilder.result()); + return resultBuilder.result(); + } + + /** + *

Encodes directly to a tensor.

+ * + *

If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order + * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small + * it will be truncated.

+ * + *

If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token + * position as value.

+ * + *

If the tensor is any other type IllegalArgumentException is thrown.

+ */ + @Override + public Tensor encode(String rawInput, Language language, TensorType type) { + if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { + // Build to a list first since we can't reverse a tensor builder + List values = encode(rawInput, language); + + long maxSize = values.size(); + if (type.dimensions().get(0).size().isPresent()) + maxSize = Math.min(maxSize, type.dimensions().get(0).size().get()); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int i = 0; i < maxSize; i++) + builder.cell(values.get(i), i); + return builder.build(); + } + else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) { + // Build to a list first since we can't reverse a tensor builder + List values = segment(rawInput, language); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int i = 0; i < values.size(); i++) + builder.cell(TensorAddress.ofLabels(values.get(i)), i); + return builder.build(); + } + else { + throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type); + } + } + + private void segment(String input, Language language, + ResultBuilder resultBuilder) { + Model model = resolveFrom(language); + algorithm.segment(input, resultBuilder, model); + } + + private Model resolveFrom(Language language) { + // Disregard language if there is default model + if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); + if (models.containsKey(language)) return models.get(language); + throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured"); + } + + public String normalize(String s) { + StringBuilder b = new StringBuilder(s.length() + 1); + boolean queuedSpace = true; // Always start by one space + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if (s.charAt(i) == ' ') { + queuedSpace = true; + } + else { + if (queuedSpace) { + b.append(SentencePieceAlgorithm.spaceSymbol); + queuedSpace = false; + } + b.append(c); + } + } + return b.toString(); + } + + public static class Builder { + + private final Map models = new HashMap<>(); + private boolean collapseUnknowns = true; + private Scoring scoring = Scoring.fewestSegments; + + public Builder() { + } + + private Builder(SentencePieceConfig config) { + collapseUnknowns = config.collapseUnknowns(); + scoring = config.scoring() == SentencePieceConfig.Scoring.fewestSegments ? Scoring.fewestSegments + : Scoring.highestScore; + for (SentencePieceConfig.Model model : config.model()) { + addModel(Language.fromLanguageTag(model.language()), model.path()); + } + } + + public void addModel(Language language, Path model) { + models.put(language, model); + } + + /** + * Adds the model that will be used if the language is unknown, OR only one model is specified. + * The same as addModel(Language.UNKNOWN, model). + */ + public Builder addDefaultModel(Path model) { + addModel(Language.UNKNOWN, model); + return this; + } + public Map getModels() { return models; } + + /** + * Sets whether consecutive unknown character should be collapsed into one large unknown token (default) + * or be returned as single character tokens. + */ + public Builder setCollapseUnknowns(boolean collapseUnknowns) { + this.collapseUnknowns = collapseUnknowns; + return this; + } + public boolean getCollapseUnknowns() { return collapseUnknowns; } + + /** Sets the scoring strategy to use when picking a segmentation. Default: fewestSegments. */ + public Builder setScoring(Scoring scoring) { + this.scoring = scoring; + return this; + } + public Scoring getScoring() { return scoring; } + + public SentencePieceEncoder build() { + if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); + return new SentencePieceEncoder(this); + } + + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java new file mode 100644 index 00000000000..782030a8e4d --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/TokenType.java @@ -0,0 +1,13 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * SentencePiece token types + * + * @author bratseth + */ +enum TokenType { + + text, control, userDefined, unknown, unused + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java new file mode 100644 index 00000000000..8e7c2db2ed3 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Trie.java @@ -0,0 +1,36 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +import java.util.HashMap; +import java.util.Map; + +/** + * A simple trie for sentencepiece token lookups. + * + * @author bratseth + */ +class Trie { + + final Node root = new Node(); + + void add(TokenType type, int id, String word, float score) { + Node current = root; + for (char l : word.toCharArray()) + current = current.children.computeIfAbsent(l, c -> new Node()); + current.type = type; + current.id = id; + current.score = score; + } + + static class Node { + + Integer id; + TokenType type; + Float score; + final Map children = new HashMap<>(); + + boolean isToken() { return type != null; } + + } + +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java new file mode 100644 index 00000000000..3f97277c489 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/package-info.java @@ -0,0 +1,7 @@ +// Copyright 2021 Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package com.yahoo.language.sentencepiece; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; -- cgit v1.2.3