diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-09-16 13:35:13 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-09-16 13:35:13 +0200 |
commit | 8e7357639fcc2805f867162ae62d710888606d56 (patch) | |
tree | 60fbae0b8d23c049dec4cffab008010dee4a9d06 /linguistics | |
parent | f84bde56a310096c7019c7d899573e36ea5a7316 (diff) |
Encode to sparse tensor
Diffstat (limited to 'linguistics')
3 files changed, 17 insertions, 0 deletions
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 8df0848870e..136d07721de 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -882,6 +882,7 @@
       "public void <init>(com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder)",
       "public java.util.List segment(java.lang.String, com.yahoo.language.Language)",
       "public java.util.List encode(java.lang.String, com.yahoo.language.Language)",
+      "public com.yahoo.tensor.Tensor encode(java.lang.String, com.yahoo.language.Language, com.yahoo.tensor.TensorType)",
       "public java.lang.String normalize(java.lang.String)"
     ],
     "fields": []
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
index 31b85c75314..c7b131cc439 100644
--- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -8,6 +8,7 @@ import com.yahoo.io.IOUtils;
 import com.yahoo.language.Language;
 import com.yahoo.language.process.Segmenter;
 import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
 import com.yahoo.tensor.TensorType;
 import sentencepiece.SentencepieceModel;
@@ -121,6 +122,15 @@ public class SentencePieceEncoder implements Segmenter {
                 builder.cell(values.get(i), i);
             return builder.build();
         }
+        else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) {
+            // Build to a list first since we can't reverse a tensor builder
+            List<String> values = segment(rawInput, language);
+            Tensor.Builder builder = Tensor.Builder.of(type);
+            for (int i = 0; i < values.size(); i++)
+                builder.cell(TensorAddress.ofLabels(values.get(i)), i);
+            return builder.build();
+        }
         else {
             throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type);
         }
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index b8fb8a2fdbb..5b77324a6fc 100644
--- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -52,6 +52,12 @@ public class SentencePieceTest {
     }
     @Test
+    public void testSparseTensorEncoding() {
+        var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+        tester.assertEncoded("hello", "tensor(token{})", "{lo:1.0,'▁hel':0.0}");
+    }
+
+    @Test
     public void testNoCollapse() {
         var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
                 .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())