diff options
Diffstat (limited to 'linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java')
-rw-r--r-- | linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java index 9a43d22ca4b..31b85c75314 100644 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -18,7 +18,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import java.util.stream.Collectors; /** @@ -108,9 +107,19 @@ public class SentencePieceEncoder implements Segmenter { /** * Encodes directly to a tensor. */ - public Tensor encode(String input, Language language, TensorType type) { + public Tensor encode(String rawInput, Language language, TensorType type) { if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { - return null; + // Build to a list first since we can't reverse a tensor builder + List<Integer> values = encode(rawInput, language); + + long maxSize = values.size(); + if (type.dimensions().get(0).size().isPresent()) + maxSize = Math.min(maxSize, type.dimensions().get(0).size().get()); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int i = 0; i < maxSize; i++) + builder.cell(values.get(i), i); + return builder.build(); } else { throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type); @@ -185,7 +194,7 @@ public class SentencePieceEncoder implements Segmenter { private static abstract class ResultBuilder<RESULTTYPE> { - private RESULTTYPE result; + private final RESULTTYPE result; ResultBuilder(RESULTTYPE result) { this.result = result; |