From f84bde56a310096c7019c7d899573e36ea5a7316 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 16 Sep 2021 13:18:11 +0200 Subject: Encode to dense tensor --- .../language/sentencepiece/SentencePieceEncoder.java | 17 +++++++++++++---- .../yahoo/language/sentencepiece/SentencePieceTest.java | 14 ++++++++++++++ .../language/sentencepiece/SentencePieceTester.java | 9 +++++++++ .../src/main/java/com/yahoo/tensor/IndexedTensor.java | 4 +--- 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java index 9a43d22ca4b..31b85c75314 100644 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -18,7 +18,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import java.util.stream.Collectors; /** @@ -108,9 +107,19 @@ public class SentencePieceEncoder implements Segmenter { /** * Encodes directly to a tensor. 
*/ - public Tensor encode(String input, Language language, TensorType type) { + public Tensor encode(String rawInput, Language language, TensorType type) { if (type.dimensions().size() == 1 && type.dimensions().get(0).isIndexed()) { - return null; + // Build to a list first since we can't reverse a tensor builder + List<Integer> values = encode(rawInput, language); + + long maxSize = values.size(); + if (type.dimensions().get(0).size().isPresent()) + maxSize = Math.min(maxSize, type.dimensions().get(0).size().get()); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int i = 0; i < maxSize; i++) + builder.cell(values.get(i), i); + return builder.build(); } else { throw new IllegalArgumentException("Don't know how to encode with SentencePiece into " + type); @@ -185,7 +194,7 @@ public class SentencePieceEncoder implements Segmenter { private static abstract class ResultBuilder<RESULTTYPE> { - private RESULTTYPE result; + private final RESULTTYPE result; ResultBuilder(RESULTTYPE result) { this.result = result; diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index f86bc2f716b..b8fb8a2fdbb 100644 --- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -3,6 +3,7 @@ package com.yahoo.language.sentencepiece; import com.yahoo.language.Language; +import com.yahoo.tensor.Tensor; import org.junit.Test; import java.io.File; @@ -33,10 +34,23 @@ public class SentencePieceTest { tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁)", "(", "/", "&", "#", "(", "sm", "all", ")", "/", "\"", "in", "▁qu", "otes", "\")"); tester.assertSegmented("x.400AS", "▁x", ".", "4", "00", "AS"); tester.assertSegmented("A normal sentence. 
Yes one more.", "▁", "A", "▁normal", "▁sentence", ".", "▁", "Y", "es", "▁one", "▁more", "."); + } + + @Test + public void testIntegerListEncoding() { + var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); tester.assertEncoded("hello, world!", 908, 1418, 9934, 501, 9960); tester.assertEncoded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960); } + @Test + public void testDenseTensorEncoding() { + var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); + tester.assertEncoded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]"); + tester.assertEncoded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]"); + tester.assertEncoded("hello, world!", "tensor(d[2])", "[908,1418]"); + } + @Test public void testNoCollapse() { var tester = new SentencePieceTester(new SentencePieceEncoder.Builder() diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java index dee9be5aa7e..1ba7c9b472d 100644 --- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java @@ -4,10 +4,13 @@ package com.yahoo.language.sentencepiece; import com.yahoo.language.Language; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.nio.file.Path; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; class SentencePieceTester { @@ -29,6 +32,12 @@ class SentencePieceTester { assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray()); } + public void assertEncoded(String input, String tensorType, String tensor) { + TensorType type = TensorType.fromSpec(tensorType); + Tensor expected = Tensor.from(type, tensor); + 
assertEquals(expected, encoder.encode(input, Language.UNKNOWN, type)); + } + public void assertSegmented(String input, String... expectedSegments) { assertSegmented(Language.UNKNOWN, input, expectedSegments); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index d822a5c6b8b..b618f935f5c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -401,8 +401,6 @@ public abstract class IndexedTensor implements Tensor { TensorType type(); - - /** Sets a value by its standard value order index */ void cellByDirectIndex(long index, double value); @@ -414,7 +412,7 @@ public abstract class IndexedTensor implements Tensor { /** A bound builder can create the double array directly */ public static abstract class BoundBuilder extends Builder implements DirectIndexBuilder { - private DimensionSizes sizes; + private final DimensionSizes sizes; private static DimensionSizes dimensionSizesOf(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); -- cgit v1.2.3