about | summary | refs | log | tree | commit | diff | stats
path: root/linguistics
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@gmail.com> 2021-09-16 12:11:44 +0200
committer: Jon Bratseth <bratseth@gmail.com> 2021-09-16 12:11:44 +0200
commit: 39a052f324aeeb4c0eb2d4313edf57ddbc4db2c7 (patch)
tree: 438897b88a0e053c49a49e3360a1b5accd29a885 /linguistics
parent: a2afdafbffcbc09594fd629c65746ec253f180be (diff)
Use a result builder
Diffstat (limited to 'linguistics')
-rw-r--r-- linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java 74
1 files changed, 53 insertions, 21 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
index 4bf808bec0c..9a43d22ca4b 100644
--- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -7,6 +7,8 @@ import com.google.inject.Inject;
import com.yahoo.io.IOUtils;
import com.yahoo.language.Language;
import com.yahoo.language.process.Segmenter;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import sentencepiece.SentencepieceModel;
import java.io.IOException;
@@ -75,8 +77,14 @@ public class SentencePieceEncoder implements Segmenter {
@Override
public List<String> segment(String rawInput, Language language) {
String input = normalize(rawInput);
- SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1];
- return segment(input, language, segmentEnds, (segmentStart, segmentEnd) -> input.substring(segmentStart, segmentEnd));
+ var resultBuilder = new ResultBuilder<List<String>>(new ArrayList<>()) {
+ public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) {
+ result().add(input.substring(segmentStart, segmentEnd));
+ }
+ };
+ segment(input, language, resultBuilder);
+ Collections.reverse(resultBuilder.result());
+ return resultBuilder.result();
}
/**
@@ -87,18 +95,33 @@ public class SentencePieceEncoder implements Segmenter {
* @return the list of zero or more token ids resulting from segmenting the input text
*/
public List<Integer> encode(String rawInput, Language language) {
- String input = normalize(rawInput);
- SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1];
- return segment(input, language, segmentEnds, (segmentStart, segmentEnd) -> segmentEnds[segmentEnd].id);
+ var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
+ public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) {
+ result().add(segmentEnds[segmentEnd].id);
+ }
+ };
+ segment(normalize(rawInput), language, resultBuilder);
+ Collections.reverse(resultBuilder.result());
+ return resultBuilder.result();
+ }
+
+ /**
+ * Encodes directly to a tensor.
+ */
+ public Tensor encode(String input, Language language, TensorType type) {
+ if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
+ return null;
+ }
+ else {
+ throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type);
+ }
}
- private <ITEMTYPE> List<ITEMTYPE> segment(String input, Language language,
- SegmentEnd[] segmentEnds,
- BiFunction<Integer, Integer, ITEMTYPE> resultItemMapper) {
+ private <RESULTTYPE> void segment(String input, Language language, ResultBuilder<RESULTTYPE> resultBuilder) {
Model model = resolveFrom(language);
- float unknownScore = model.minScore - 10.0f;
- segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0);
+ SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1];
+ segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0);
int start = 0;
while (start < input.length()) { // segment from this position to the end of the text
Trie.Node node = model.tokens.root;
@@ -111,13 +134,12 @@ public class SentencePieceEncoder implements Segmenter {
addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds);
}
else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable
- addSegment(TokenType.unknown, 0, start, start + 1, unknownScore, segmentEnds);
+ addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds);
}
}
start++;
}
-
- return createResult(input, segmentEnds, resultItemMapper);
+ createResult(input, segmentEnds, resultBuilder);
}
private Model resolveFrom(Language language) {
@@ -137,18 +159,16 @@ public class SentencePieceEncoder implements Segmenter {
}
}
- private <ITEMTYPE> List<ITEMTYPE> createResult(String input, SegmentEnd[] segmentEnds,
- BiFunction<Integer, Integer, ITEMTYPE> resultItemMapper) {
- List<ITEMTYPE> result = new ArrayList<>();
+ private <RESULTTYPE> void createResult(String input, SegmentEnd[] segmentEnds, ResultBuilder<RESULTTYPE> resultBuilder) {
if (collapseUnknowns) {
int segmentEnd = input.length();
int collapsedSegmentEnd = segmentEnd;
while (segmentEnd > 0) {
if (segmentEnds[segmentEnd].type != TokenType.unknown ) {
if (collapsedSegmentEnd != segmentEnd) { // We have deferred an unknown collapsed segment
- result.add(resultItemMapper.apply(segmentEnd, collapsedSegmentEnd));
+ resultBuilder.add(segmentEnd, collapsedSegmentEnd, segmentEnds);
}
- result.add(resultItemMapper.apply(segmentEnds[segmentEnd].segmentStart, segmentEnd));
+ resultBuilder.add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds);
collapsedSegmentEnd = segmentEnds[segmentEnd].segmentStart;
}
segmentEnd = segmentEnds[segmentEnd].segmentStart;
@@ -157,12 +177,24 @@ public class SentencePieceEncoder implements Segmenter {
else {
int segmentEnd = input.length();
while (segmentEnd > 0) {
- result.add(resultItemMapper.apply(segmentEnds[segmentEnd].segmentStart, segmentEnd));
+ resultBuilder.add(segmentEnds[segmentEnd].segmentStart, segmentEnd, segmentEnds);
segmentEnd = segmentEnds[segmentEnd].segmentStart;
}
}
- Collections.reverse(result);
- return result;
+ }
+
+ private static abstract class ResultBuilder<RESULTTYPE> {
+
+ private RESULTTYPE result;
+
+ ResultBuilder(RESULTTYPE result) {
+ this.result = result;
+ }
+
+ abstract void add(int start, int end, SegmentEnd[] segmentEnds);
+
+ RESULTTYPE result() { return result; }
+
}
private final class SegmentEnd {