diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-09-28 21:19:41 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-09-28 21:19:41 +0200 |
commit | e7e659e9d26401c8c36300d4760d4e34acd26d0a (patch) | |
tree | 4c8b869a9ef991a6edda1c3a80e433b3b1690bbd /linguistics-components/src | |
parent | 35223653327b86a059d23c543bbac3611d43775f (diff) |
encode -> embed
Diffstat (limited to 'linguistics-components/src')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java (renamed from linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java) | 29 | ||||
-rw-r--r-- | linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def | 2 | ||||
-rw-r--r-- | linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java | 8 | ||||
-rw-r--r-- | linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java | 18 | ||||
-rw-r--r-- | linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java | 20 |
5 files changed, 38 insertions, 39 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index b6659ebeaa3..116dd15f563 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -4,7 +4,7 @@ package com.yahoo.language.sentencepiece; import com.google.common.annotations.Beta; import com.google.inject.Inject; import com.yahoo.language.Language; -import com.yahoo.language.process.Encoder; +import com.yahoo.language.process.Embedder; import com.yahoo.language.process.Segmenter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -19,26 +19,25 @@ import java.util.Map; import java.util.stream.Collectors; /** - * Integration with https://github.com/google/sentencepiece - * through http://docs.djl.ai/extensions/sentencepiece/index.html + * A native Java implementation of SentencePiece - see https://github.com/google/sentencepiece * - * SentencePiece is a language-agnostic tokenizer for neural nets. + * SentencePiece is a language-agnostic segmenter and embedder for neural nets. * * @author bratseth */ @Beta -public class SentencePieceEncoder implements Segmenter, Encoder { +public class SentencePieceEmbedder implements Segmenter, Embedder { private final Map<Language, Model> models; private final SentencePieceAlgorithm algorithm; @Inject - public SentencePieceEncoder(SentencePieceConfig config) { + public SentencePieceEmbedder(SentencePieceConfig config) { this(new Builder(config)); } - public SentencePieceEncoder(Builder builder) { + public SentencePieceEmbedder(Builder builder) { algorithm = new SentencePieceAlgorithm(builder.collapseUnknowns, builder.getScoring()); models = builder.getModels().entrySet() @@ -46,7 +45,7 @@ public class SentencePieceEncoder implements Segmenter, Encoder { .map(e -> new Model(e.getKey(), e.getValue())) .collect(Collectors.toUnmodifiableMap(m -> m.language, m -> m)); if (models.isEmpty()) - throw new IllegalArgumentException("SentencePieceEncoder requires at least one model configured"); + throw new IllegalArgumentException("SentencePieceEmbedder requires at least one model configured"); } /** @@ -77,7 +76,7 @@ public class SentencePieceEncoder implements Segmenter, Encoder { * @return the list of zero or more token ids resulting from segmenting the input text */ @Override - public List<Integer> encode(String rawInput, Language language) { + public List<Integer> embed(String rawInput, Language language) { var resultBuilder = new ResultBuilder<List<Integer>>(new ArrayList<>()) { public void add(int segmentStart, int segmentEnd, SentencePieceAlgorithm.SegmentEnd[] segmentEnds) { result().add(segmentEnds[segmentEnd].id); @@ -89,7 +88,7 @@ public class SentencePieceEncoder implements Segmenter, Encoder { } /** - * <p>Encodes directly to a tensor.</p> + * <p>Embeds text into a tensor.</p> * * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order * they were encountered in the text. If the dimension is bound and too large it will be zero padded, if too small @@ -101,10 +100,10 @@ public class SentencePieceEncoder implements Segmenter, Encoder { * <p>If the tensor is any other type IllegalArgumentException is thrown.</p> */ @Override - public Tensor encode(String rawInput, Language language, TensorType type) { + public Tensor embed(String rawInput, Language language, TensorType type) { if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { // Build to a list first since we can't reverse a tensor builder - List<Integer> values = encode(rawInput, language); + List<Integer> values = embed(rawInput, language); long maxSize = values.size(); if (type.dimensions().get(0).size().isPresent()) @@ -125,7 +124,7 @@ public class SentencePieceEncoder implements Segmenter, Encoder { return builder.build(); } else { - throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type); + throw new IllegalArgumentException("Don't know how to embed with SentencePiece into " + type); } } @@ -210,9 +209,9 @@ public class SentencePieceEncoder implements Segmenter, Encoder { } public Scoring getScoring() { return scoring; } - public SentencePieceEncoder build() { + public SentencePieceEmbedder build() { if (models.isEmpty()) throw new IllegalStateException("At least one model must be supplied"); - return new SentencePieceEncoder(this); + return new SentencePieceEmbedder(this); } } diff --git a/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def b/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def index b91c0c45dc4..16ada78688a 100644 --- a/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def +++ b/linguistics-components/src/main/resources/configdefinitions/language.sentencepiece.sentence-piece.def @@ -1,6 +1,6 @@ # Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -# Configures com.yahoo.language.sentencepiece.SentencePieceEncoder +# Configures com.yahoo.language.sentencepiece.SentencePieceEmbedder namespace=language.sentencepiece diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java index edbbe21ec53..1ed2271f774 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceConfigurationTest.java @@ -15,7 +15,7 @@ public class SentencePieceConfigurationTest { public void testEnglishTokenization() { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); - var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); tester.assertSegmented("KHJKJHHKJHHSH hello", "▁", "KHJKJHHKJHHSH", "▁hel", "lo"); } @@ -25,7 +25,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); b.collapseUnknowns(false); - var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); } @@ -34,7 +34,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("unknown", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); b.scoring(SentencePieceConfig.Scoring.highestScore); - var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented("hello", "▁h", "el", "lo"); } @@ -43,7 +43,7 @@ public class SentencePieceConfigurationTest { var b = new SentencePieceConfig.Builder(); addModel("ja", "src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model", b); addModel("en", "src/test/models/sentencepiece/en.wiki.bpe.vs10000.model", b); - var tester = new SentencePieceTester(new SentencePieceEncoder(b.build())); + var tester = new SentencePieceTester(new SentencePieceEmbedder(b.build())); tester.assertSegmented(Language.JAPANESE, "いくつかの通常のテキスト", "▁", "いく", "つか", "の", "通常", "の", "テ", "キ", "スト"); tester.assertSegmented(Language.ENGLISH, "hello", "▁hel", "lo"); tester.assertSegmented(Language.JAPANESE, "hello", "▁h", "ell", "o"); diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index d60d7386d4b..939f8ebe9d3 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -38,27 +38,27 @@ public class SentencePieceTest { @Test public void testIntegerListEncoding() { var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - tester.assertEncoded("hello, world!", 908, 1418, 9934, 501, 9960); - tester.assertEncoded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960); + tester.assertEmbedded("hello, world!", 908, 1418, 9934, 501, 9960); + tester.assertEmbedded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960); } @Test public void testDenseTensorEncoding() { var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - tester.assertEncoded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]"); - tester.assertEncoded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]"); - tester.assertEncoded("hello, world!", "tensor(d[2])", "[908,1418]"); + tester.assertEmbedded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]"); + tester.assertEmbedded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]"); + tester.assertEmbedded("hello, world!", "tensor(d[2])", "[908,1418]"); } @Test public void testSparseTensorEncoding() { var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); - tester.assertEncoded("hello", "tensor(token{})", "{lo:1.0,'▁hel':0.0}"); + tester.assertEmbedded("hello", "tensor(token{})", "{lo:1.0,'▁hel':0.0}"); } @Test public void testNoCollapse() { - var tester = new SentencePieceTester(new SentencePieceEncoder.Builder() + var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder() .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) .setCollapseUnknowns(false)); tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); @@ -66,7 +66,7 @@ public class SentencePieceTest { @Test public void testHighestScore() { - var tester = new SentencePieceTester(new SentencePieceEncoder.Builder() + var tester = new SentencePieceTester(new SentencePieceEmbedder.Builder() .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) .setScoring(Scoring.highestScore)); tester.assertSegmented("h", "▁h"); @@ -77,7 +77,7 @@ public class SentencePieceTest { @Test public void testMultiLanguageTokenization() { - SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder(); + SentencePieceEmbedder.Builder builder = new SentencePieceEmbedder.Builder(); builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); builder.addModel(Language.ENGLISH, new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); var tester = new SentencePieceTester(builder); diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java index 1ba7c9b472d..c4cb13a3d23 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java @@ -14,28 +14,28 @@ import static org.junit.Assert.assertEquals; class SentencePieceTester { - private final SentencePieceEncoder encoder; + private final SentencePieceEmbedder embedder; public SentencePieceTester(Path model) { - this(new SentencePieceEncoder.Builder().addDefaultModel(model)); + this(new SentencePieceEmbedder.Builder().addDefaultModel(model)); } - public SentencePieceTester(SentencePieceEncoder.Builder builder) { + public SentencePieceTester(SentencePieceEmbedder.Builder builder) { this(builder.build()); } - public SentencePieceTester(SentencePieceEncoder encoder) { - this.encoder = encoder; + public SentencePieceTester(SentencePieceEmbedder embedder) { + this.embedder = embedder; } - public void assertEncoded(String input, Integer... expectedCodes) { - assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray()); + public void assertEmbedded(String input, Integer... expectedCodes) { + assertArrayEquals(expectedCodes, embedder.embed(input, Language.UNKNOWN).toArray()); } - public void assertEncoded(String input, String tensorType, String tensor) { + public void assertEmbedded(String input, String tensorType, String tensor) { TensorType type = TensorType.fromSpec(tensorType); Tensor expected = Tensor.from(type, tensor); - assertEquals(expected, encoder.encode(input, Language.UNKNOWN, type)); + assertEquals(expected, embedder.embed(input, Language.UNKNOWN, type)); } public void assertSegmented(String input, String... expectedSegments) { @@ -43,7 +43,7 @@ class SentencePieceTester { } public void assertSegmented(Language language, String input, String... expectedSegments) { - assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray()); + assertArrayEquals(expectedSegments, embedder.segment(input, language).toArray()); } } |