From f1921848eff763bc99c46e53733df7bcae04fa7b Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 16 Jan 2017 15:55:41 +0100 Subject: Add tensor document summary field --- .../com/yahoo/io/GrowableBufferOutputStream.java | 6 ++-- .../yahoo/tensor/serialization/BinaryFormat.java | 6 ++-- .../tensor/serialization/DenseBinaryFormat.java | 33 +++++++++++++++++++--- .../tensor/serialization/SparseBinaryFormat.java | 19 +++++++++++-- .../tensor/serialization/TypedBinaryFormat.java | 13 +++++++-- .../serialization/DenseBinaryFormatTestCase.java | 4 ++- .../serialization/SparseBinaryFormatTestCase.java | 5 +++- 7 files changed, 70 insertions(+), 16 deletions(-) (limited to 'vespajlib') diff --git a/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java b/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java index 85b249432d4..b8dfedc8ede 100644 --- a/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java +++ b/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java @@ -9,13 +9,11 @@ import java.util.LinkedList; import java.util.Iterator; import java.nio.ByteBuffer; - /** - * - * @author Bjorn Borud + * @author Bjørn Borud */ public class GrowableBufferOutputStream extends OutputStream { -// private static final int MINIMUM_BUFFERSIZE = (64 * 1024); + private ByteBuffer lastBuffer; private ByteBuffer directBuffer; private LinkedList bufferList = new LinkedList<>(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java index 9b0ccdcb6c8..a6949fdf57f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java @@ -6,6 +6,8 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.Optional; + /** * Representation of a specific binary format with functions for serializing a Tensor object into * this format or de-serializing binary data into a Tensor object. @@ -23,9 +25,9 @@ interface BinaryFormat { /** * Deserialize the given binary data into a Tensor object. * - * @param type the expected abstract type of the tensor to serialize + * @param type the expected abstract type of the tensor to serialize, or empty to use type information from the data * @param buffer the buffer containing the tensor binary data */ - Tensor decode(TensorType type, GrowableByteBuffer buffer); + Tensor decode(Optional type, GrowableByteBuffer buffer); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 0a97576d5b7..3ff82ea774b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -6,9 +6,9 @@ import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import com.yahoo.text.Utf8; import java.util.Iterator; +import java.util.Optional; /** * Implementation of a dense binary format for a tensor on the form: @@ -46,14 +46,23 @@ public class DenseBinaryFormat implements BinaryFormat { } @Override - public Tensor decode(TensorType type, GrowableByteBuffer buffer) { - DimensionSizes sizes = decodeDimensionSizes(type, buffer); + public Tensor decode(Optional optionalType, GrowableByteBuffer buffer) { + TensorType type; + DimensionSizes sizes; + if (optionalType.isPresent()) { + type = optionalType.get(); + sizes = decodeAndValidateDimensionSizes(type, buffer); + } + else { + type = decodeType(buffer); + sizes = sizesFromType(type); + } Tensor.Builder builder = Tensor.Builder.of(type, sizes); decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder); return builder.build(); } - private DimensionSizes decodeDimensionSizes(TensorType type, GrowableByteBuffer buffer) { + private DimensionSizes decodeAndValidateDimensionSizes(TensorType type, GrowableByteBuffer buffer) { int dimensionCount = buffer.getInt1_4Bytes(); if (type.dimensions().size() != dimensionCount) throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + @@ -79,6 +88,22 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } + private TensorType decodeType(GrowableByteBuffer buffer) { + int dimensionCount = buffer.getInt1_4Bytes(); + TensorType.Builder builder = new TensorType.Builder(); + for (int i = 0; i < dimensionCount; i++) + builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); + return builder.build(); + } + + /** Returns dimension sizes from a type consisting of fully specified, indexed dimensions only */ + private DimensionSizes sizesFromType(TensorType type) { + DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); + for (int i = 0; i < type.dimensions().size(); i++) + builder.set(i, type.dimensions().get(i).size().get()); + return builder.build(); + } + private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { for (int i = 0; i < sizes.totalSize(); i++) builder.cellByDirectIndex(i, buffer.getDouble()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 8ab23c8d77c..6b0443c9bfe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -53,8 +53,15 @@ class SparseBinaryFormat implements BinaryFormat { } @Override - public Tensor decode(TensorType type, GrowableByteBuffer buffer) { - consumeAndValidateDimensions(type, buffer); + public Tensor decode(Optional optionalType, GrowableByteBuffer buffer) { + TensorType type; + if (optionalType.isPresent()) { + type = optionalType.get(); + consumeAndValidateDimensions(optionalType.get(), buffer); + } + else { + type = decodeType(buffer); + } Tensor.Builder builder = Tensor.Builder.of(type); decodeCells(buffer, builder, type); return builder.build(); @@ -75,6 +82,14 @@ class SparseBinaryFormat implements BinaryFormat { } } + private TensorType decodeType(GrowableByteBuffer buffer) { + int numDimensions = buffer.getInt1_4Bytes(); + TensorType.Builder builder = new TensorType.Builder(); + for (int i = 0; i < numDimensions; ++i) + builder.mapped(buffer.getUtf8String()); + return builder.build(); + } + private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { int numCells = buffer.getInt1_4Bytes(); for (int i = 0; i < numCells; ++i) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 19c1810d928..6413602c532 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -7,6 +7,8 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.Optional; + /** * Class used by clients for serializing a Tensor object into binary format or * de-serializing binary data into a Tensor object. @@ -38,8 +40,15 @@ public class TypedBinaryFormat { return result; } - public static Tensor decode(TensorType type, byte[] data) { - GrowableByteBuffer buffer = GrowableByteBuffer.wrap(data); + /** + * Decode some data to a tensor + * + * @param type the type to decode and validate to, or empty to use the type given in the data + * @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array + * @return the resulting tensor + * @throws IllegalArgumentException if the tensor data was invalid + */ + public static Tensor decode(Optional type, GrowableByteBuffer buffer) { int formatType = buffer.getInt1_4Bytes(); switch (formatType) { case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 15e82e6b15c..8a3d2879201 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -2,11 +2,13 @@ package com.yahoo.tensor.serialization; import com.google.common.collect.Sets; +import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; import org.junit.Ignore; import org.junit.Test; import java.util.Arrays; +import java.util.Optional; import java.util.Set; import static org.junit.Assert.assertEquals; @@ -46,7 +48,7 @@ public class DenseBinaryFormatTestCase { private void assertSerialization(Tensor tensor) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor); + Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(tensor.type()), GrowableByteBuffer.wrap(encodedTensor)); assertEquals(tensor, decodedTensor); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 283aa90cf65..65f6b92f91e 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -2,10 +2,12 @@ package com.yahoo.tensor.serialization; import com.google.common.collect.Sets; +import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; import org.junit.Test; import java.util.Arrays; +import java.util.Optional; import java.util.Set; import static org.junit.Assert.assertEquals; @@ -46,7 +48,8 @@ public class SparseBinaryFormatTestCase { private void assertSerialization(Tensor tensor) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor); + Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(tensor.type()), + GrowableByteBuffer.wrap(encodedTensor)); assertEquals(tensor, decodedTensor); } -- cgit v1.2.3