diff options
Diffstat (limited to 'linguistics/src/main/java/com/yahoo/language')
-rw-r--r-- | linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java | 74 |
1 files changed, 53 insertions, 21 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java index 4bf808bec0c..9a43d22ca4b 100644 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -7,6 +7,8 @@ import com.google.inject.Inject; import com.yahoo.io.IOUtils; import com.yahoo.language.Language; import com.yahoo.language.process.Segmenter; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import sentencepiece.SentencepieceModel; import java.io.IOException; @@ -75,8 +77,14 @@ public class SentencePieceEncoder implements Segmenter { @Override public List<String> segment(String rawInput, Language language) { String input = normalize(rawInput); - SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; - return segment(input, language, segmentEnds, (segmentStart, segmentEnd) -> input.substring(segmentStart, segmentEnd)); + var resultBuilder = new ResultBuilder<List<String>>(new ArrayList<>()) { + public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) { + result().add(input.substring(segmentStart, segmentEnd)); + } + }; + segment(input, language, resultBuilder); + Collections.reverse(resultBuilder.result()); + return resultBuilder.result(); } /** @@ -87,18 +95,33 @@ public class SentencePieceEncoder implements Segmenter { * @return the list of zero or more token ids resulting from segmenting the input text */ public List<Integer> encode(String rawInput, Language language) { - String input = normalize(rawInput); - SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; - return segment(input, language, segmentEnds, (segmentStart, segmentEnd) -> segmentEnds[segmentEnd].id); + var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) { + public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) { + result().add(segmentEnds[segmentEnd].id); + } + }; + segment(normalize(rawInput), language, resultBuilder); + Collections.reverse(resultBuilder.result()); + return resultBuilder.result(); + } + + /** + * Encodes directly to a tensor. + */ + public Tensor encode(String input, Language language, TensorType type) { + if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { + return null; + } + else { + throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type); + } } - private <ITEMTYPE> List<ITEMTYPE> segment(String input, Language language, - SegmentEnd[] segmentEnds, - BiFunction<Integer, Integer, ITEMTYPE> resultItemMapper) { + private <RESULTTYPE> void segment(String input, Language language, ResultBuilder<RESULTTYPE> resultBuilder) { Model model = resolveFrom(language); - float unknownScore = model.minScore - 10.0f; - segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0); + SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; + segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0); int start = 0; while (start < input.length()) { // segment from this position to the end of the text Trie.Node node = model.tokens.root; @@ -111,13 +134,12 @@ public class SentencePieceEncoder implements Segmenter { addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds); } else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable - addSegment(TokenType.unknown, 0, start, start + 1, unknownScore, segmentEnds); + addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds); } } start++; } - - return createResult(input, segmentEnds, resultItemMapper); + createResult(input, segmentEnds, resultBuilder); } private Model resolveFrom(Language language) { @@ -137,18 +159,16 @@ public class SentencePieceEncoder implements Segmenter { } } - private <ITEMTYPE> List<ITEMTYPE> createResult(String input, SegmentEnd[] segmentEnds, - BiFunction<Integer, Integer, ITEMTYPE> resultItemMapper) { - List<ITEMTYPE> result = new ArrayList<>(); + private <RESULTTYPE> void createResult(String input, SegmentEnd[] segmentEnds, ResultBuilder<RESULTTYPE> resultBuilder) { if (collapseUnknowns) { int segmentEnd = input.length(); int collapsedSegmentEnd = segmentEnd; while (segmentEnd > 0) { if (segmentEnds[segmentEnd].type != TokenType.unknown ) { if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment - result.add(resultItemMapper.apply(segmentEnd, collapsedSegmentEnd)); + resultBuilder.add(segmentEnd, collapsedSegmentEnd, segmentEnds); } - result.add(resultItemMapper.apply(segmentEnds[segmentEnd].segmentStart, segmentEnd)); + resultBuilder.add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart; } segmentEnd = segmentEnds[segmentEnd].segmentStart; @@ -157,12 +177,24 @@ public class SentencePieceEncoder implements Segmenter { else { int segmentEnd = input.length(); while (segmentEnd > 0) { - result.add(resultItemMapper.apply(segmentEnds[segmentEnd].segmentStart, segmentEnd)); + resultBuilder.add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds); segmentEnd = segmentEnds[segmentEnd].segmentStart; } } - Collections.reverse(result); - return result; + } + + private static abstract class ResultBuilder<RESULTTYPE> { + + private RESULTTYPE result; + + ResultBuilder(RESULTTYPE result) { + this.result = result; + } + + abstract void add(int start, int end, SegmentEnd[] segmentEnds); + + RESULTTYPE result() { return result; } + } private final class SegmentEnd { |