diff options
Diffstat (limited to 'vespajlib/src/main/java')
5 files changed, 63 insertions, 14 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java b/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java index 85b249432d4..b8dfedc8ede 100644 --- a/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java +++ b/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java @@ -9,13 +9,11 @@ import java.util.LinkedList; import java.util.Iterator; import java.nio.ByteBuffer; - /** - * - * @author <a href="mailto:borud@yahoo-inc.com">Bjorn Borud</a> + * @author Bjørn Borud */ public class GrowableBufferOutputStream extends OutputStream { -// private static final int MINIMUM_BUFFERSIZE = (64 * 1024); + private ByteBuffer lastBuffer; private ByteBuffer directBuffer; private LinkedList<ByteBuffer> bufferList = new LinkedList<>(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java index 9b0ccdcb6c8..a6949fdf57f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java @@ -6,6 +6,8 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.Optional; + /** * Representation of a specific binary format with functions for serializing a Tensor object into * this format or de-serializing binary data into a Tensor object. @@ -23,9 +25,9 @@ interface BinaryFormat { /** * Deserialize the given binary data into a Tensor object. 
* - * @param type the expected abstract type of the tensor to serialize + * @param type the expected abstract type of the tensor to serialize, or empty to use type information from the data * @param buffer the buffer containing the tensor binary data */ - Tensor decode(TensorType type, GrowableByteBuffer buffer); + Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 0a97576d5b7..3ff82ea774b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -6,9 +6,9 @@ import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import com.yahoo.text.Utf8; import java.util.Iterator; +import java.util.Optional; /** * Implementation of a dense binary format for a tensor on the form: @@ -46,14 +46,23 @@ public class DenseBinaryFormat implements BinaryFormat { } @Override - public Tensor decode(TensorType type, GrowableByteBuffer buffer) { - DimensionSizes sizes = decodeDimensionSizes(type, buffer); + public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) { + TensorType type; + DimensionSizes sizes; + if (optionalType.isPresent()) { + type = optionalType.get(); + sizes = decodeAndValidateDimensionSizes(type, buffer); + } + else { + type = decodeType(buffer); + sizes = sizesFromType(type); + } Tensor.Builder builder = Tensor.Builder.of(type, sizes); decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder); return builder.build(); } - private DimensionSizes decodeDimensionSizes(TensorType type, GrowableByteBuffer buffer) { + private DimensionSizes decodeAndValidateDimensionSizes(TensorType type, GrowableByteBuffer buffer) { int dimensionCount = 
buffer.getInt1_4Bytes(); if (type.dimensions().size() != dimensionCount) throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + @@ -79,6 +88,22 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } + private TensorType decodeType(GrowableByteBuffer buffer) { + int dimensionCount = buffer.getInt1_4Bytes(); + TensorType.Builder builder = new TensorType.Builder(); + for (int i = 0; i < dimensionCount; i++) + builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); + return builder.build(); + } + + /** Returns dimension sizes from a type consisting of fully specified, indexed dimensions only */ + private DimensionSizes sizesFromType(TensorType type) { + DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); + for (int i = 0; i < type.dimensions().size(); i++) + builder.set(i, type.dimensions().get(i).size().get()); + return builder.build(); + } + private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { for (int i = 0; i < sizes.totalSize(); i++) builder.cellByDirectIndex(i, buffer.getDouble()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 8ab23c8d77c..6b0443c9bfe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -53,8 +53,15 @@ class SparseBinaryFormat implements BinaryFormat { } @Override - public Tensor decode(TensorType type, GrowableByteBuffer buffer) { - consumeAndValidateDimensions(type, buffer); + public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) { + TensorType type; + if (optionalType.isPresent()) { + type = optionalType.get(); + consumeAndValidateDimensions(optionalType.get(), buffer); + } + else { + type 
= decodeType(buffer); + } Tensor.Builder builder = Tensor.Builder.of(type); decodeCells(buffer, builder, type); return builder.build(); @@ -75,6 +82,14 @@ class SparseBinaryFormat implements BinaryFormat { } } + private TensorType decodeType(GrowableByteBuffer buffer) { + int numDimensions = buffer.getInt1_4Bytes(); + TensorType.Builder builder = new TensorType.Builder(); + for (int i = 0; i < numDimensions; ++i) + builder.mapped(buffer.getUtf8String()); + return builder.build(); + } + private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { int numCells = buffer.getInt1_4Bytes(); for (int i = 0; i < numCells; ++i) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 19c1810d928..6413602c532 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -7,6 +7,8 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.Optional; + /** * Class used by clients for serializing a Tensor object into binary format or * de-serializing binary data into a Tensor object. 
@@ -38,8 +40,15 @@ public class TypedBinaryFormat { return result; } - public static Tensor decode(TensorType type, byte[] data) { - GrowableByteBuffer buffer = GrowableByteBuffer.wrap(data); + /** + * Decode some data to a tensor + * + * @param type the type to decode and validate to, or empty to use the type given in the data + * @param buffer the buffer containing the data, use GrowableByteBuffer.wrap(byte[]) if you have a byte array + * @return the resulting tensor + * @throws IllegalArgumentException if the tensor data was invalid + */ + public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) { int formatType = buffer.getInt1_4Bytes(); switch (formatType) { case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer); |