summaryrefslogtreecommitdiffstats
path: root/linguistics-components
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-09-28 21:19:41 +0200
committerJon Bratseth <bratseth@gmail.com>2021-09-28 21:19:41 +0200
commite7e659e9d26401c8c36300d4760d4e34acd26d0a (patch)
tree4c8b869a9ef991a6edda1c3a80e433b3b1690bbd /linguistics-components
parent35223653327b86a059d23c543bbac3611d43775f (diff)
encode -> embed
Diffstat (limited to 'linguistics-components')
-rw-r--r--linguistics-components/abi-spec.json20
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java (renamed from linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java)29
-rw-r--r--linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def2
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java8
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java18
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java20
6 files changed, 48 insertions, 49 deletions
diff --git a/linguistics-components/abi-spec.json b/linguistics-components/abi-spec.json
index 5b6729c58ef..808ec3af082 100644
--- a/linguistics-components/abi-spec.json
+++ b/linguistics-components/abi-spec.json
@@ -148,7 +148,7 @@
"public static final java.lang.String[] CONFIG_DEF_SCHEMA"
]
},
- "com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder": {
+ "com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder": {
"superClass": "java.lang.Object",
"interfaces": [],
"attributes": [
@@ -157,31 +157,31 @@
"methods": [
"public void <init>()",
"public void addModel(com.yahoo.language.Language, java.nio.file.Path)",
- "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder addDefaultModel(java.nio.file.Path)",
+ "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder addDefaultModel(java.nio.file.Path)",
"public java.util.Map getModels()",
- "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder setCollapseUnknowns(boolean)",
+ "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setCollapseUnknowns(boolean)",
"public boolean getCollapseUnknowns()",
- "public com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder setScoring(com.yahoo.language.sentencepiece.Scoring)",
+ "public com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder setScoring(com.yahoo.language.sentencepiece.Scoring)",
"public com.yahoo.language.sentencepiece.Scoring getScoring()",
- "public com.yahoo.language.sentencepiece.SentencePieceEncoder build()"
+ "public com.yahoo.language.sentencepiece.SentencePieceEmbedder build()"
],
"fields": []
},
- "com.yahoo.language.sentencepiece.SentencePieceEncoder": {
+ "com.yahoo.language.sentencepiece.SentencePieceEmbedder": {
"superClass": "java.lang.Object",
"interfaces": [
"com.yahoo.language.process.Segmenter",
- "com.yahoo.language.process.Encoder"
+ "com.yahoo.language.process.Embedder"
],
"attributes": [
"public"
],
"methods": [
"public void <init>(com.yahoo.language.sentencepiece.SentencePieceConfig)",
- "public void <init>(com.yahoo.language.sentencepiece.SentencePieceEncoder$Builder)",
+ "public void <init>(com.yahoo.language.sentencepiece.SentencePieceEmbedder$Builder)",
"public java.util.List segment(java.lang.String, com.yahoo.language.Language)",
- "public java.util.List encode(java.lang.String, com.yahoo.language.Language)",
- "public com.yahoo.tensor.Tensor encode(java.lang.String, com.yahoo.language.Language, com.yahoo.tensor.TensorType)",
+ "public java.util.List embed(java.lang.String, com.yahoo.language.Language)",
+ "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.Language, com.yahoo.tensor.TensorType)",
"public java.lang.String normalize(java.lang.String)"
],
"fields": []
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index b6659ebeaa3..116dd15f563 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -4,7 +4,7 @@ package com.yahoo.language.sentencepiece;
import com.google.common.annotations.Beta;
import com.google.inject.Inject;
import com.yahoo.language.Language;
-import com.yahoo.language.process.Encoder;
+import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -19,26 +19,25 @@ import java.util.Map;
import java.util.stream.Collectors;
/**
- * Integration with https://github.com/google/sentencepiece
- * through http://docs.djl.ai/extensions/sentencepiece/index.html
+ * A native Java implementation of SentencePiece - see https://github.com/google/sentencepiece
*
- * SentencePiece is a language-agnostic tokenizer for neural nets.
+ * SentencePiece is a language-agnostic segmenter and embedder for neural nets.
*
* @author bratseth
*/
@Beta
-public class SentencePieceEncoder implements Segmenter, Encoder {
+public class SentencePieceEmbedder implements Segmenter, Embedder {
private final Map<Language, Model> models;
private final SentencePieceAlgorithm algorithm;
@Inject
- public SentencePieceEncoder(SentencePieceConfig config) {
+ public SentencePieceEmbedder(SentencePieceConfig config) {
this(new Builder(config));
}
- public SentencePieceEncoder(Builder builder) {
+ public SentencePieceEmbedder(Builder builder) {
algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring());
models = builder.getModels().entrySet()
@@ -46,7 +45,7 @@ public class SentencePieceEncoder implements Segmenter, Encoder {
.map(e -> new Model(e.getKey(), e.getValue()))
.collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m));
if (models.isEmpty())
- throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured");
+ throw new IllegalArgumentException("SentencePieceEmbedder requires at least one model configured");
}
/**
@@ -77,7 +76,7 @@ public class SentencePieceEncoder implements Segmenter, Encoder {
* @return the list of zero or more token ids resulting from segmenting the input text
*/
@Override
- public List<Integer> encode(String rawInput, Language language) {
+ public List<Integer> embed(String rawInput, Language language) {
var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) {
public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) {
result().add(segmentEnds[segmentEnd].id);
@@ -89,7 +88,7 @@ public class SentencePieceEncoder implements Segmenter, Encoder {
}
/**
- * <p>Encodes directly to a tensor.</p>
+ * <p>Embeds text into a tensor.</p>
*
* <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order
* they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small
@@ -101,10 +100,10 @@ public class SentencePieceEncoder implements Segmenter, Encoder {
* <p>If the tensor is any other type IllegalArgumentException is thrown.</p>
*/
@Override
- public Tensor encode(String rawInput, Language language, TensorType type) {
+ public Tensor embed(String rawInput, Language language, TensorType type) {
if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) {
// Build to a list first since we can't reverse a tensor builder
- List<Integer> values = encode(rawInput, language);
+ List<Integer> values = embed(rawInput, language);
long maxSize = values.size();
if (type.dimensions().get(0).size().isPresent())
@@ -125,7 +124,7 @@ public class SentencePieceEncoder implements Segmenter, Encoder {
return builder.build();
}
else {
- throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type);
+ throw new IllegalArgumentException("Don't know how to embed with SentencePiece into " + type);
}
}
@@ -210,9 +209,9 @@ public class SentencePieceEncoder implements Segmenter, Encoder {
}
public Scoring getScoring() { return scoring; }
- public SentencePieceEncoder build() {
+ public SentencePieceEmbedder build() {
if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied");
- return new SentencePieceEncoder(this);
+ return new SentencePieceEmbedder(this);
}
}
diff --git a/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def b/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def
index b91c0c45dc4..16ada78688a 100644
--- a/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def
+++ b/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def
@@ -1,6 +1,6 @@
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-# Configures com.yahoo.language.sentencepiece.SentencePieceEncoder
+# Configures com.yahoo.language.sentencepiece.SentencePieceEmbedder
namespace=language.sentencepiece
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
index edbbe21ec53..1ed2271f774 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java
@@ -15,7 +15,7 @@ public class SentencePieceConfigurationTest {
public void testEnglishTokenization() {
var b = new SentencePieceConfig.Builder();
addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
- var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence");
tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo");
}
@@ -25,7 +25,7 @@ public class SentencePieceConfigurationTest {
var b = new SentencePieceConfig.Builder();
addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
b.collapseUnknowns(false);
- var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
}
@@ -34,7 +34,7 @@ public class SentencePieceConfigurationTest {
var b = new SentencePieceConfig.Builder();
addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
b.scoring(SentencePieceConfig.Scoring.highestScore);
- var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented("hello", "▁h", "el", "lo");
}
@@ -43,7 +43,7 @@ public class SentencePieceConfigurationTest {
var b = new SentencePieceConfig.Builder();
addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b);
addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b);
- var tester = new SentencePieceTester(new SentencePieceEncoder(b.build()));
+ var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build()));
tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト");
tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo");
tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o");
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index d60d7386d4b..939f8ebe9d3 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -38,27 +38,27 @@ public class SentencePieceTest {
@Test
public void testIntegerListEncoding() {
var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
- tester.assertEncoded("hello, world!", 908, 1418, 9934, 501, 9960);
- tester.assertEncoded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960);
+ tester.assertEmbedded("hello, world!", 908, 1418, 9934, 501, 9960);
+ tester.assertEmbedded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960);
}
@Test
public void testDenseTensorEncoding() {
var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
- tester.assertEncoded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]");
- tester.assertEncoded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]");
- tester.assertEncoded("hello, world!", "tensor(d[2])", "[908,1418]");
+ tester.assertEmbedded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]");
+ tester.assertEmbedded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]");
+ tester.assertEmbedded("hello, world!", "tensor(d[2])", "[908,1418]");
}
@Test
public void testSparseTensorEncoding() {
var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
- tester.assertEncoded("hello", "tensor(token{})", "{lo:1.0,'▁hel':0.0}");
+ tester.assertEmbedded("hello", "tensor(token{})", "{lo:1.0,'▁hel':0.0}");
}
@Test
public void testNoCollapse() {
- var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
+ var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder()
.addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
.setCollapseUnknowns(false));
tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
@@ -66,7 +66,7 @@ public class SentencePieceTest {
@Test
public void testHighestScore() {
- var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
+ var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder()
.addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
.setScoring(Scoring.highestScore));
tester.assertSegmented("h", "▁h");
@@ -77,7 +77,7 @@ public class SentencePieceTest {
@Test
public void testMultiLanguageTokenization() {
- SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder();
+ SentencePieceEmbedder.Builder builder = new SentencePieceEmbedder.Builder();
builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath());
builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
var tester = new SentencePieceTester(builder);
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
index 1ba7c9b472d..c4cb13a3d23 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -14,28 +14,28 @@ import static org.junit.Assert.assertEquals;
class SentencePieceTester {
- private final SentencePieceEncoder encoder;
+ private final SentencePieceEmbedder embedder;
public SentencePieceTester(Path model) {
- this(new SentencePieceEncoder.Builder().addDefaultModel(model));
+ this(new SentencePieceEmbedder.Builder().addDefaultModel(model));
}
- public SentencePieceTester(SentencePieceEncoder.Builder builder) {
+ public SentencePieceTester(SentencePieceEmbedder.Builder builder) {
this(builder.build());
}
- public SentencePieceTester(SentencePieceEncoder encoder) {
- this.encoder = encoder;
+ public SentencePieceTester(SentencePieceEmbedder embedder) {
+ this.embedder = embedder;
}
- public void assertEncoded(String input, Integer... expectedCodes) {
- assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
+ public void assertEmbedded(String input, Integer... expectedCodes) {
+ assertArrayEquals(expectedCodes, embedder.embed(input, Language.UNKNOWN).toArray());
}
- public void assertEncoded(String input, String tensorType, String tensor) {
+ public void assertEmbedded(String input, String tensorType, String tensor) {
TensorType type = TensorType.fromSpec(tensorType);
Tensor expected = Tensor.from(type, tensor);
- assertEquals(expected, encoder.encode(input, Language.UNKNOWN, type));
+ assertEquals(expected, embedder.embed(input, Language.UNKNOWN, type));
}
public void assertSegmented(String input, String... expectedSegments) {
@@ -43,7 +43,7 @@ class SentencePieceTester {
}
public void assertSegmented(Language language, String input, String... expectedSegments) {
- assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
+ assertArrayEquals(expectedSegments, embedder.segment(input, language).toArray());
}
}