diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-09-16 22:46:17 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-09-16 22:46:17 +0200 |
commit | 583101eec4032dda2310df146962e8471d70f188 (patch) | |
tree | e8f9f53f29d7c2bf82665d3256a7f333c696d1b8 /linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java | |
parent | daab62042f34575d545dcd0b6fd100e232848c85 (diff) |
Encoder interface
Diffstat (limited to 'linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java')
-rw-r--r-- | linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java index c7b131cc439..74ed79b267b 100644 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -6,6 +6,7 @@ import com.google.common.annotations.Beta; import com.google.inject.Inject; import com.yahoo.io.IOUtils; import com.yahoo.language.Language; +import com.yahoo.language.process.Encoder; import com.yahoo.language.process.Segmenter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -30,7 +31,7 @@ import java.util.stream.Collectors; * @author bratseth */ @Beta -public class SentencePieceEncoder implements Segmenter { +public class SentencePieceEncoder implements Segmenter, Encoder { // TODO: Support characters beyond BMP enum TokenType { text, control, userDefined, unknown, unused } @@ -94,6 +95,7 @@ public class SentencePieceEncoder implements Segmenter { * @param language the model to use, or Language.UNKNOWN to use the default model if any * @return the list of zero or more token ids resulting from segmenting the input text */ + @Override public List<Integer> encode(String rawInput, Language language) { var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) { public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) { @@ -106,8 +108,18 @@ public class SentencePieceEncoder implements Segmenter { } /** - * Encodes directly to a tensor. + * <p>Encodes directly to a tensor.</p> + * + * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order + * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small + * it will be truncated.</p> + * + * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token + * position as value.</p> + * + * <p>If the tensor is any other type IllegalArgumentException is thrown.</p> */ + @Override public Tensor encode(String rawInput, Language language, TensorType type) { if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { // Build to a list first since we can't reverse a tensor builder |