summaryrefslogtreecommitdiffstats
path: root/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-09-16 22:46:17 +0200
committerJon Bratseth <bratseth@gmail.com>2021-09-16 22:46:17 +0200
commit583101eec4032dda2310df146962e8471d70f188 (patch)
treee8f9f53f29d7c2bf82665d3256a7f333c696d1b8 /linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
parentdaab62042f34575d545dcd0b6fd100e232848c85 (diff)
Encoder interface
Diffstat (limited to 'linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java')
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java16
1 files changed, 14 insertions, 2 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
index c7b131cc439..74ed79b267b 100644
--- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -6,6 +6,7 @@ import com.google.common.annotations.Beta;
import com.google.inject.Inject;
import com.yahoo.io.IOUtils;
import com.yahoo.language.Language;
+import com.yahoo.language.process.Encoder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -30,7 +31,7 @@ import java.util.stream.Collectors;
* @author bratseth
*/
@Beta
-public class SentencePieceEncoder implements Segmenter {
+public class SentencePieceEncoder implements Segmenter, Encoder {
// TODO: Support characters beyond BMP
enum TokenType { text, control, userDefined, unknown, unused }
@@ -94,6 +95,7 @@ public class SentencePieceEncoder implements Segmenter {
* @param language the model to use, or Language.UNKNOWN to use the default model if any
* @return the list of zero or more token ids resulting from segmenting the input text
*/
+ @Override
public List<Integer> encode(String rawInput, Language language) {
var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
public void add(int segmentStart, int segmentEnd, SegmentEnd[] segmentEnds) {
@@ -106,8 +108,18 @@ public class SentencePieceEncoder implements Segmenter {
}
/**
- * Encodes directly to a tensor.
+ * <p>Encodes directly to a tensor.</p>
+ *
+ * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order
+ * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small
+ * it will be truncated.</p>
+ *
+ * <p>If the tensor type is1-d sparse this will return a tensor containing the token strings as keys and the token
+ * position as value.</p>
+ *
+ * <p>If the tensor is any other type IllegalArgumentException is thrown.</p>
*/
+ @Override
public Tensor encode(String rawInput, Language language, TensorType type) {
if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
// Build to a list first since we can't reverse a tensor builder