summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-16 15:55:41 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-16 15:55:41 +0100
commitf1921848eff763bc99c46e53733df7bcae04fa7b (patch)
treeea7baff225ec91007ff6be8959deee672e71877a /vespajlib
parentbcb0aece3ab9229b2d10169e9b82781cc22d5d2e (diff)
Add tensor document summary field
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java19
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java13
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java5
7 files changed, 70 insertions, 16 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java b/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java
index 85b249432d4..b8dfedc8ede 100644
--- a/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java
+++ b/vespajlib/src/main/java/com/yahoo/io/GrowableBufferOutputStream.java
@@ -9,13 +9,11 @@ import java.util.LinkedList;
import java.util.Iterator;
import java.nio.ByteBuffer;
-
/**
- *
- * @author <a href="mailto:borud@yahoo-inc.com">Bjorn Borud</a>
+ * @author Bjørn Borud
*/
public class GrowableBufferOutputStream extends OutputStream {
-// private static final int MINIMUM_BUFFERSIZE = (64 * 1024);
+
private ByteBuffer lastBuffer;
private ByteBuffer directBuffer;
private LinkedList<ByteBuffer> bufferList = new LinkedList<>();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
index 9b0ccdcb6c8..a6949fdf57f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
@@ -6,6 +6,8 @@ import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.Optional;
+
/**
* Representation of a specific binary format with functions for serializing a Tensor object into
* this format or de-serializing binary data into a Tensor object.
@@ -23,9 +25,9 @@ interface BinaryFormat {
/**
* Deserialize the given binary data into a Tensor object.
*
- * @param type the expected abstract type of the tensor to serialize
+ * @param type the expected abstract type of the tensor to deserialize, or empty to use type information from the data
* @param buffer the buffer containing the tensor binary data
*/
- Tensor decode(TensorType type, GrowableByteBuffer buffer);
+ Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 0a97576d5b7..3ff82ea774b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -6,9 +6,9 @@ import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import com.yahoo.text.Utf8;
import java.util.Iterator;
+import java.util.Optional;
/**
* Implementation of a dense binary format for a tensor on the form:
@@ -46,14 +46,23 @@ public class DenseBinaryFormat implements BinaryFormat {
}
@Override
- public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
- DimensionSizes sizes = decodeDimensionSizes(type, buffer);
+ public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
+ TensorType type;
+ DimensionSizes sizes;
+ if (optionalType.isPresent()) {
+ type = optionalType.get();
+ sizes = decodeAndValidateDimensionSizes(type, buffer);
+ }
+ else {
+ type = decodeType(buffer);
+ sizes = sizesFromType(type);
+ }
Tensor.Builder builder = Tensor.Builder.of(type, sizes);
decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder);
return builder.build();
}
- private DimensionSizes decodeDimensionSizes(TensorType type, GrowableByteBuffer buffer) {
+ private DimensionSizes decodeAndValidateDimensionSizes(TensorType type, GrowableByteBuffer buffer) {
int dimensionCount = buffer.getInt1_4Bytes();
if (type.dimensions().size() != dimensionCount)
throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
@@ -79,6 +88,22 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
+ private TensorType decodeType(GrowableByteBuffer buffer) {
+ int dimensionCount = buffer.getInt1_4Bytes();
+ TensorType.Builder builder = new TensorType.Builder();
+ for (int i = 0; i < dimensionCount; i++)
+ builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes());
+ return builder.build();
+ }
+
+ /** Returns dimension sizes from a type consisting of fully specified, indexed dimensions only */
+ private DimensionSizes sizesFromType(TensorType type) {
+ DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
+ for (int i = 0; i < type.dimensions().size(); i++)
+ builder.set(i, type.dimensions().get(i).size().get());
+ return builder.build();
+ }
+
private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
for (int i = 0; i < sizes.totalSize(); i++)
builder.cellByDirectIndex(i, buffer.getDouble());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 8ab23c8d77c..6b0443c9bfe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -53,8 +53,15 @@ class SparseBinaryFormat implements BinaryFormat {
}
@Override
- public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
- consumeAndValidateDimensions(type, buffer);
+ public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
+ TensorType type;
+ if (optionalType.isPresent()) {
+ type = optionalType.get();
+ consumeAndValidateDimensions(optionalType.get(), buffer);
+ }
+ else {
+ type = decodeType(buffer);
+ }
Tensor.Builder builder = Tensor.Builder.of(type);
decodeCells(buffer, builder, type);
return builder.build();
@@ -75,6 +82,14 @@ class SparseBinaryFormat implements BinaryFormat {
}
}
+ private TensorType decodeType(GrowableByteBuffer buffer) {
+ int numDimensions = buffer.getInt1_4Bytes();
+ TensorType.Builder builder = new TensorType.Builder();
+ for (int i = 0; i < numDimensions; ++i)
+ builder.mapped(buffer.getUtf8String());
+ return builder.build();
+ }
+
private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
int numCells = buffer.getInt1_4Bytes();
for (int i = 0; i < numCells; ++i) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 19c1810d928..6413602c532 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -7,6 +7,8 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.Optional;
+
/**
* Class used by clients for serializing a Tensor object into binary format or
* de-serializing binary data into a Tensor object.
@@ -38,8 +40,15 @@ public class TypedBinaryFormat {
return result;
}
- public static Tensor decode(TensorType type, byte[] data) {
- GrowableByteBuffer buffer = GrowableByteBuffer.wrap(data);
+ /**
+ * Decode some data to a tensor
+ *
+ * @param type the type to decode and validate to, or empty to use the type given in the data
+ * @param buffer the buffer containing the data, use GrowableByteBuffer.wrap(byte[]) if you have a byte array
+ * @return the resulting tensor
+ * @throws IllegalArgumentException if the tensor data was invalid
+ */
+ public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) {
int formatType = buffer.getInt1_4Bytes();
switch (formatType) {
case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 15e82e6b15c..8a3d2879201 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -2,11 +2,13 @@
package com.yahoo.tensor.serialization;
import com.google.common.collect.Sets;
+import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
import org.junit.Ignore;
import org.junit.Test;
import java.util.Arrays;
+import java.util.Optional;
import java.util.Set;
import static org.junit.Assert.assertEquals;
@@ -46,7 +48,7 @@ public class DenseBinaryFormatTestCase {
private void assertSerialization(Tensor tensor) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor);
+ Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(tensor.type()), GrowableByteBuffer.wrap(encodedTensor));
assertEquals(tensor, decodedTensor);
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index 283aa90cf65..65f6b92f91e 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -2,10 +2,12 @@
package com.yahoo.tensor.serialization;
import com.google.common.collect.Sets;
+import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
import org.junit.Test;
import java.util.Arrays;
+import java.util.Optional;
import java.util.Set;
import static org.junit.Assert.assertEquals;
@@ -46,7 +48,8 @@ public class SparseBinaryFormatTestCase {
private void assertSerialization(Tensor tensor) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor);
+ Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(tensor.type()),
+ GrowableByteBuffer.wrap(encodedTensor));
assertEquals(tensor, decodedTensor);
}