summaryrefslogtreecommitdiffstats
path: root/linguistics/src/main/java/com/yahoo/language
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-09-16 13:18:11 +0200
committerJon Bratseth <bratseth@gmail.com>2021-09-16 13:18:11 +0200
commitf84bde56a310096c7019c7d899573e36ea5a7316 (patch)
treee2e6d0a5c7d27ac540e3d82132a988b9d31890de /linguistics/src/main/java/com/yahoo/language
parent39a052f324aeeb4c0eb2d4313edf57ddbc4db2c7 (diff)
Encode to dense tensor
Diffstat (limited to 'linguistics/src/main/java/com/yahoo/language')
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java17
1 files changed, 13 insertions, 4 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
index 9a43d22ca4b..31b85c75314 100644
--- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -18,7 +18,6 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.function.BiFunction;
import java.util.stream.Collectors;
/**
@@ -108,9 +107,19 @@ public class SentencePieceEncoder implements Segmenter {
/**
* Encodes directly to a tensor.
*/
- public Tensor encode(String input, Language language, TensorType type) {
+ public Tensor encode(String rawInput, Language language, TensorType type) {
if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
- return null;
+ // Build to a list first since we can't reverse a tensor builder
+ List<Integer> values = encode(rawInput, language);
+
+ long maxSize = values.size();
+ if (type.dimensions().get(0).size().isPresent())
+ maxSize = Math.min(maxSize, type.dimensions().get(0).size().get());
+
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ for (int i = 0; i < maxSize; i++)
+ builder.cell(values.get(i), i);
+ return builder.build();
}
else {
throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type);
@@ -185,7 +194,7 @@ public class SentencePieceEncoder implements Segmenter {
private static abstract class ResultBuilder<RESULTTYPE> {
- private RESULTTYPE result;
+ private final RESULTTYPE result;
ResultBuilder(RESULTTYPE result) {
this.result = result;