summary | refs | log | tree | commit | diff | stats
path: root/linguistics
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@gmail.com>, 2021-09-16 13:35:13 +0200
committer: Jon Bratseth <bratseth@gmail.com>, 2021-09-16 13:35:13 +0200
commit: 8e7357639fcc2805f867162ae62d710888606d56 (patch)
tree: 60fbae0b8d23c049dec4cffab008010dee4a9d06 /linguistics
parent: f84bde56a310096c7019c7d899573e36ea5a7316 (diff)
Encode to sparse tensor
Diffstat (limited to 'linguistics')
-rw-r--r-- linguistics/abi-spec.json | 1
-rw-r--r-- linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java | 10
-rw-r--r-- linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java | 6
3 files changed, 17 insertions, 0 deletions
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 8df0848870e..136d07721de 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -882,6 +882,7 @@
"public void <init>(com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder)",
"public java.util.List segment(java.lang.String, com.yahoo.language.Language)",
"public java.util.List encode(java.lang.String, com.yahoo.language.Language)",
+ "public com.yahoo.tensor.Tensor encode(java.lang.String, com.yahoo.language.Language, com.yahoo.tensor.TensorType)",
"public java.lang.String normalize(java.lang.String)"
],
"fields": []
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
index 31b85c75314..c7b131cc439 100644
--- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -8,6 +8,7 @@ import com.yahoo.io.IOUtils;
import com.yahoo.language.Language;
import com.yahoo.language.process.Segmenter;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import sentencepiece.SentencepieceModel;
@@ -121,6 +122,15 @@ public class SentencePieceEncoder implements Segmenter {
builder.cell(values.get(i), i);
return builder.build();
}
+ else if (type.dimensions().size() == 1 && type.dimensions().get(0).isMapped()) {
+ // Build to a list first since we can't reverse a tensor builder
+ List<String> values = segment(rawInput, language);
+
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ for (int i = 0; i < values.size(); i++)
+ builder.cell(TensorAddress.ofLabels(values.get(i)), i);
+ return builder.build();
+ }
else {
throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type);
}
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index b8fb8a2fdbb..5b77324a6fc 100644
--- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -52,6 +52,12 @@ public class SentencePieceTest {
}
@Test
+ public void testSparseTensorEncoding() {
+ var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ tester.assertEncoded("hello", "tensor(token{})", "{lo:1.0,'▁hel':0.0}");
+ }
+
+ @Test
public void testNoCollapse() {
var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
.addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())