diff options
author | Jon Bratseth <bratseth@oath.com> | 2021-09-17 08:01:29 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-17 08:01:29 +0200 |
commit | 4d31a75b8a249593d0a3503669d3399b980c8be1 (patch) | |
tree | 050c7f725f4efdb736912c405008bd1d49bd782c /linguistics | |
parent | daab62042f34575d545dcd0b6fd100e232848c85 (diff) | |
parent | a0f2ddb8b759a928329996050c818f5a4fae90b0 (diff) |
Merge pull request #19180 from vespa-engine/bratseth/encoder-interface
Bratseth/encoder interface
Diffstat (limited to 'linguistics')
11 files changed, 374 insertions, 228 deletions
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 136d07721de..e8687b5c9f4 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -328,6 +328,20 @@ ], "fields": [] }, + "com.yahoo.language.process.Encoder": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public", + "interface", + "abstract" + ], + "methods": [ + "public abstract java.util.List encode(java.lang.String, com.yahoo.language.Language)", + "public abstract com.yahoo.tensor.Tensor encode(java.lang.String, com.yahoo.language.Language, com.yahoo.tensor.TensorType)" + ], + "fields": [] + }, "com.yahoo.language.process.GramSplitter$Gram": { "superClass": "java.lang.Object", "interfaces": [], @@ -701,6 +715,23 @@ ], "fields": [] }, + "com.yahoo.language.sentencepiece.Scoring": { + "superClass": "java.lang.Enum", + "interfaces": [], + "attributes": [ + "public", + "final", + "enum" + ], + "methods": [ + "public static com.yahoo.language.sentencepiece.Scoring[] values()", + "public static com.yahoo.language.sentencepiece.Scoring valueOf(java.lang.String)" + ], + "fields": [ + "public static final enum com.yahoo.language.sentencepiece.Scoring highestScore", + "public static final enum com.yahoo.language.sentencepiece.Scoring fewestSegments" + ] + }, "com.yahoo.language.sentencepiece.SentencePieceConfig$Builder": { "superClass": "java.lang.Object", "interfaces": [ @@ -846,33 +877,17 @@ "public java.util.Map getModels()", "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder setCollapseUnknowns(boolean)", "public boolean getCollapseUnknowns()", - "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder setScoring(com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring)", - "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring getScoring()", + "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder setScoring(com.yahoo.language.sentencepiece.Scoring)", + "public com.yahoo.language.sentencepiece.Scoring getScoring()", "public com.yahoo.language.sentencepiece.SentencePieceEncoder build()" ], "fields": [] }, - "com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring": { - "superClass": "java.lang.Enum", - "interfaces": [], - "attributes": [ - "public", - "final", - "enum" - ], - "methods": [ - "public static com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring[] values()", - "public static com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring valueOf(java.lang.String)" - ], - "fields": [ - "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring highestScore", - "public static final enum com.yahoo.language.sentencepiece.SentencePieceEncoder$Scoring fewestSegments" - ] - }, "com.yahoo.language.sentencepiece.SentencePieceEncoder": { "superClass": "java.lang.Object", "interfaces": [ - "com.yahoo.language.process.Segmenter" + "com.yahoo.language.process.Segmenter", + "com.yahoo.language.process.Encoder" ], "attributes": [ "public" @@ -886,5 +901,16 @@ "public java.lang.String normalize(java.lang.String)" ], "fields": [] + }, + "com.yahoo.language.sentencepiece.Trie": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>()" + ], + "fields": [] } }
\ No newline at end of file diff --git a/linguistics/src/main/java/com/yahoo/language/process/Encoder.java b/linguistics/src/main/java/com/yahoo/language/process/Encoder.java new file mode 100644 index 00000000000..91de16f669b --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/process/Encoder.java @@ -0,0 +1,39 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.process; + +import com.yahoo.language.Language; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.List; + +/** + * An encoder converts a text string to a tensor or list of tokens + * + * @author bratseth + */ +public interface Encoder { + + /** + * Encodes text into tokens in a list of ids. + * + * @param text the text to encode + * @param language the language of the text, or UNKNOWN to use language independent encoding + * @return the text encoded to a list of segment ids + * @throws IllegalArgumentException if the language is not supported by this encoder + */ + List<Integer> encode(String text, Language language); + + /** + * Encodes text into tokens in a tensor. + * The information contained in the encoding may depend on the tensor type. + * + * @param text the text to encode + * @param language the language of the text, or UNKNOWN to use language independent encoding + * @param tensorType the type of the ttensor to be returned + * @return the tex encoded into a tensor of the supplied type + * @throws IllegalArgumentException if the language or tensor type is not supported by this encoder + */ + Tensor encode(String text, Language language, TensorType tensorType); + +} diff --git a/linguistics/src/main/java/com/yahoo/language/process/Stemmer.java b/linguistics/src/main/java/com/yahoo/language/process/Stemmer.java index da8a73407ff..a2d0d0a84c9 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Stemmer.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Stemmer.java @@ -6,9 +6,9 @@ import com.yahoo.language.Language; import java.util.List; /** - * <p>Interface providing stemming of single words.</p> + * Interface providing stemming of single words. * - * @author <a href="mailto:mathiasm@yahoo-inc.com">Mathias Mølster Lidal</a> + * @author Mathias Mølster Lidal */ public interface Stemmer { diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java new file mode 100644 index 00000000000..74f300057dc --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Model.java @@ -0,0 +1,60 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +import com.yahoo.io.IOUtils; +import com.yahoo.language.Language; +import sentencepiece.SentencepieceModel; + +import java.io.IOException; +import java.nio.file.Path; + +/** + * A SentencePiece model + * + * @author bratseth + */ +final class Model { + + final Path source; + final Language language; + final float minScore; + final float maxScore; + final Trie tokens = new Trie(); + + Model(Language language, Path path) { + try { + this.source = path; + this.language = language; + var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile())); + float minScore = Float.MAX_VALUE; + float maxScore = Float.MIN_VALUE; + for (int i = 0; i < sp.getPiecesCount(); i++) { + var piece = sp.getPieces(i); + tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); + minScore = Math.min(piece.getScore(), minScore); + maxScore = Math.max(piece.getScore(), maxScore); + } + this.minScore = minScore; + this.maxScore = maxScore; + } catch (IOException e) { + throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e); + } + } + + private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) { + switch (type) { + case USER_DEFINED : return TokenType.userDefined; + case UNKNOWN : return TokenType.unknown; + case NORMAL : return TokenType.text; + case CONTROL : return TokenType.control; + case UNUSED : return TokenType.unused; + default : throw new IllegalArgumentException("Unknkown token type " + type); + } + } + + @Override + public String toString() { + return "SentencePiece model for " + language + ": '" + source + "'"; + } + +} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java new file mode 100644 index 00000000000..2141505374c --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/ResultBuilder.java @@ -0,0 +1,47 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * Builds a result from a sentencepiece tokenization by being called for each segment in reverse + * + * @param <RESULTTYPE> the type of result this produces + * @author bratseth + */ +abstract class ResultBuilder<RESULTTYPE> { + + private final RESULTTYPE result; + + ResultBuilder(RESULTTYPE result) { + this.result = result; + } + + /** Called for each segment, starting from the last and working backwards */ + abstract void add(int start, int end, SentencePieceAlgorithm.SegmentEnd[] segmentEnds); + + RESULTTYPE result() {return result;} + + void build(String input, SentencePieceAlgorithm.SegmentEnd[] segmentEnds, boolean collapseUnknowns) { + if (collapseUnknowns) { + int segmentEnd = input.length(); + int collapsedSegmentEnd = segmentEnd; + while (segmentEnd > 0) { + if (segmentEnds[segmentEnd].type != TokenType.unknown ) { + if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment + add(segmentEnd, collapsedSegmentEnd, segmentEnds); + } + add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); + collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart; + } + segmentEnd = segmentEnds[segmentEnd].segmentStart; + } + } + else { + int segmentEnd = input.length(); + while (segmentEnd > 0) { + add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); + segmentEnd = segmentEnds[segmentEnd].segmentStart; + } + } + } + +} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java new file mode 100644 index 00000000000..6c8560abee7 --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Scoring.java @@ -0,0 +1,17 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * The scoring strategy to use for picking segments + * + * @author bratseth + */ +public enum Scoring { + + /** Find the segmentation that has the highest score */ + highestScore, + + /** Find the segmentation that has the fewest segments, resolve ties by score sum */ + fewestSegments + +} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java new file mode 100644 index 00000000000..1659e3c0fa7 --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java @@ -0,0 +1,90 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * SentencePiece algorithm implementation + * + * @author bratseth + */ +class SentencePieceAlgorithm { + + // TODO: Support characters beyond BMP + + static final char spaceSymbol = '▁'; + + private final boolean collapseUnknowns; + private final Scoring scoring; + + SentencePieceAlgorithm(boolean collapseUnknowns, Scoring scoring) { + this.collapseUnknowns = collapseUnknowns; + this.scoring = scoring; + } + + public <RESULTTYPE> void segment(String input, ResultBuilder<RESULTTYPE> resultBuilder, Model model) { + SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; + segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0); + int start = 0; + while (start < input.length()) { // segment from this position to the end of the text + Trie.Node node = model.tokens.root; + int characterPosition = start; + while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position + node = node.children.get(input.charAt(characterPosition++)); + int length = characterPosition - start; + if (node != null && node.isToken() && node.type != TokenType.unused) { + float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score; + addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds); + } + else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable + addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds); + } + } + start++; + } + resultBuilder.build(input, segmentEnds, collapseUnknowns); + } + + private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) { + if (segmentEnds[end] == null || + segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) { + segmentEnds[end] = new SegmentEnd(type, id, + segmentEnds[start].pathScoreSum + score, + segmentEnds[start].pathSegmentCount + 1, + start); + } + } + + final class SegmentEnd { + + final TokenType type; + final int id; + final float pathScoreSum; + final int pathSegmentCount; + final int segmentStart; + + SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) { + this.type = type; + this.id = id; + this.pathScoreSum = pathScoreSum; + this.pathSegmentCount = pathSegmentCount; + this.segmentStart = segmentStart; + } + + public float score() { + switch (scoring) { + case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum; + case highestScore: return pathScoreSum; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + public float scoreWith(float additionalSegmentScore) { + switch (scoring) { + case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore ); + case highestScore: return pathScoreSum + additionalSegmentScore; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + } + +} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java index c7b131cc439..b6659ebeaa3 100644 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -1,18 +1,15 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.language.sentencepiece; import com.google.common.annotations.Beta; import com.google.inject.Inject; -import com.yahoo.io.IOUtils; import com.yahoo.language.Language; +import com.yahoo.language.process.Encoder; import com.yahoo.language.process.Segmenter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import sentencepiece.SentencepieceModel; -import java.io.IOException; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; @@ -30,34 +27,19 @@ import java.util.stream.Collectors; * @author bratseth */ @Beta -public class SentencePieceEncoder implements Segmenter { - - // TODO: Support characters beyond BMP - enum TokenType { text, control, userDefined, unknown, unused } - - /** The scoring strategy to use for picking segments */ - public enum Scoring { - /** Find the segmentation that has the highest score */ - highestScore, - /** Find the segmentation that has the fewest segments, resolve ties by score sum */ - fewestSegments - } - - private static final char spaceSymbol = '▁'; - - private final boolean collapseUnknowns; - private final Scoring scoring; +public class SentencePieceEncoder implements Segmenter, Encoder { private final Map<Language, Model> models; + private final SentencePieceAlgorithm algorithm; + @Inject public SentencePieceEncoder(SentencePieceConfig config) { this(new Builder(config)); } public SentencePieceEncoder(Builder builder) { - collapseUnknowns = builder.getCollapseUnknowns(); - scoring = builder.getScoring(); + algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring()); models = builder.getModels().entrySet() .stream() @@ -78,7 +60,7 @@ public class SentencePieceEncoder implements Segmenter { public List<String> segment(String rawInput, Language language) { String input = normalize(rawInput); var resultBuilder = new ResultBuilder<List<String>>(new ArrayList<>()) { - public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) { + public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { result().add(input.substring(segmentStart, segmentEnd)); } }; @@ -94,9 +76,10 @@ public class SentencePieceEncoder implements Segmenter { * @param language the model to use, or Language.UNKNOWN to use the default model if any * @return the list of zero or more token ids resulting from segmenting the input text */ + @Override public List<Integer> encode(String rawInput, Language language) { var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) { - public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) { + public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { result().add(segmentEnds[segmentEnd].id); } }; @@ -106,8 +89,18 @@ public class SentencePieceEncoder implements Segmenter { } /** - * Encodes directly to a tensor. + * <p>Encodes directly to a tensor.</p> + * + * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order + * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small + * it will be truncated.</p> + * + * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token + * position as value.</p> + * + * <p>If the tensor is any other type IllegalArgumentException is thrown.</p> */ + @Override public Tensor encode(String rawInput, Language language, TensorType type) { if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { // Build to a list first since we can't reverse a tensor builder @@ -136,29 +129,10 @@ public class SentencePieceEncoder implements Segmenter { } } - private <RESULTTYPE> void segment(String input, Language language, ResultBuilder<RESULTTYPE> resultBuilder) { + private <RESULTTYPE> void segment(String input, Language language, + ResultBuilder<RESULTTYPE> resultBuilder) { Model model = resolveFrom(language); - - SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; - segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0); - int start = 0; - while (start < input.length()) { // segment from this position to the end of the text - Trie.Node node = model.tokens.root; - int characterPosition = start; - while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position - node = node.children.get(input.charAt(characterPosition++)); - int length = characterPosition - start; - if (node != null && node.isToken() && node.type != TokenType.unused) { - float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score; - addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds); - } - else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable - addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds); - } - } - start++; - } - createResult(input, segmentEnds, resultBuilder); + algorithm.segment(input, resultBuilder, model); } private Model resolveFrom(Language language) { @@ -168,88 +142,6 @@ public class SentencePieceEncoder implements Segmenter { throw new IllegalArgumentException("No SentencePiece model for language " + language + " is configured"); } - private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) { - if (segmentEnds[end] == null || - segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) { - segmentEnds[end] = new SegmentEnd(type, id, - segmentEnds[start].pathScoreSum + score, - segmentEnds[start].pathSegmentCount + 1, - start); - } - } - - private <RESULTTYPE> void createResult(String input, SegmentEnd[] segmentEnds, ResultBuilder<RESULTTYPE> resultBuilder) { - if (collapseUnknowns) { - int segmentEnd = input.length(); - int collapsedSegmentEnd = segmentEnd; - while (segmentEnd > 0) { - if (segmentEnds[segmentEnd].type != TokenType.unknown ) { - if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment - resultBuilder.add(segmentEnd, collapsedSegmentEnd, segmentEnds); - } - resultBuilder.add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); - collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart; - } - segmentEnd = segmentEnds[segmentEnd].segmentStart; - } - } - else { - int segmentEnd = input.length(); - while (segmentEnd > 0) { - resultBuilder.add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); - segmentEnd = segmentEnds[segmentEnd].segmentStart; - } - } - } - - private static abstract class ResultBuilder<RESULTTYPE> { - - private final RESULTTYPE result; - - ResultBuilder(RESULTTYPE result) { - this.result = result; - } - - abstract void add(int start, int end, SegmentEnd[] segmentEnds); - - RESULTTYPE result() { return result; } - - } - - private final class SegmentEnd { - - final TokenType type; - final int id; - final float pathScoreSum; - final int pathSegmentCount; - final int segmentStart; - - SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) { - this.type = type; - this.id = id; - this.pathScoreSum = pathScoreSum; - this.pathSegmentCount = pathSegmentCount; - this.segmentStart = segmentStart; - } - - public float score() { - switch (scoring) { - case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum; - case highestScore: return pathScoreSum; - default : throw new IllegalArgumentException("Unknown scoring " + scoring); - } - } - - public float scoreWith(float additionalSegmentScore) { - switch (scoring) { - case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore ); - case highestScore: return pathScoreSum + additionalSegmentScore; - default : throw new IllegalArgumentException("Unknown scoring " + scoring); - } - } - - } - public String normalize(String s) { StringBuilder b = new StringBuilder(s.length() + 1); boolean queuedSpace = true; // Always start by one space @@ -260,7 +152,7 @@ public class SentencePieceEncoder implements Segmenter { } else { if (queuedSpace) { - b.append(spaceSymbol); + b.append(SentencePieceAlgorithm.spaceSymbol); queuedSpace = false; } b.append(c); @@ -269,79 +161,6 @@ public class SentencePieceEncoder implements Segmenter { return b.toString(); } - private static TokenType toTokenType(SentencepieceModel.ModelProto.SentencePiece.Type type) { - switch (type) { - case USER_DEFINED : return TokenType.userDefined; - case UNKNOWN : return TokenType.unknown; - case NORMAL : return TokenType.text; - case CONTROL : return TokenType.control; - case UNUSED : return TokenType.unused; - default : throw new IllegalArgumentException("Unknkown token type " + type); - } - } - - private static class Trie { - - final Node root = new Node(); - - void add(TokenType type, int id, String word, float score) { - Node current = root; - for (char l : word.toCharArray()) - current = current.children.computeIfAbsent(l, c -> new Node()); - current.type = type; - current.id = id; - current.score = score; - } - - static class Node { - - Integer id; - TokenType type; - Float score; - private final Map<Character, Node> children = new HashMap<>(); - - boolean isToken() { return type != null; } - - } - - } - - private static final class Model { - - final Path source; - final Language language; - final float minScore; - final float maxScore; - final Trie tokens = new Trie(); - - Model(Language language, Path path) { - try { - this.source = path; - this.language = language; - var sp = SentencepieceModel.ModelProto.parseFrom(IOUtils.readFileBytes(path.toFile())); - float minScore = Float.MAX_VALUE; - float maxScore = Float.MIN_VALUE; - for (int i = 0; i < sp.getPiecesCount(); i++) { - var piece = sp.getPieces(i); - tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); - minScore = Math.min(piece.getScore(), minScore); - maxScore = Math.max(piece.getScore(), maxScore); - } - this.minScore = minScore; - this.maxScore = maxScore; - } - catch (IOException e) { - throw new IllegalArgumentException("Could not read a SentencePiece model from " + path, e); - } - } - - @Override - public String toString() { - return "SentencePiece model for " + language + ": '" + source + "'"; - } - - } - public static class Builder { private final Map<Language, Path> models = new HashMap<>(); diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java new file mode 100644 index 00000000000..782030a8e4d --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/TokenType.java @@ -0,0 +1,13 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * SentencePiece token types + * + * @author bratseth + */ +enum TokenType { + + text, control, userDefined, unknown, unused + +} diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java new file mode 100644 index 00000000000..f3287a49517 --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/Trie.java @@ -0,0 +1,36 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +import java.util.HashMap; +import java.util.Map; + +/** + * A simple trie for sentencepiece token lookup + * + * @author bratseth + */ +public class Trie { + + final Node root = new Node(); + + void add(TokenType type, int id, String word, float score) { + Node current = root; + for (char l : word.toCharArray()) + current = current.children.computeIfAbsent(l, c -> new Node()); + current.type = type; + current.id = id; + current.score = score; + } + + static class Node { + + Integer id; + TokenType type; + Float score; + final Map<Character, Node> children = new HashMap<>(); + + boolean isToken() { return type != null; } + + } + +} diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index 5b77324a6fc..d60d7386d4b 100644 --- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -3,7 +3,6 @@ package com.yahoo.language.sentencepiece; import com.yahoo.language.Language; -import com.yahoo.tensor.Tensor; import org.junit.Test; import java.io.File; @@ -69,7 +68,7 @@ public class SentencePieceTest { public void testHighestScore() { var tester = new SentencePieceTester(new SentencePieceEncoder.Builder() .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) - .setScoring(SentencePieceEncoder.Scoring.highestScore)); + .setScoring(Scoring.highestScore)); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); tester.assertSegmented("hel", "▁h", "el"); |