summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2017-01-10 20:16:39 +0100
committerGitHub <noreply@github.com>2017-01-10 20:16:39 +0100
commit2676b9c9987193f6c48677e30357123c4164eccc (patch)
treeaf9c84f76951392adc50c8dd4a552775d61cee36 /vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
parent697fa08ac031476eee46f7b9cc8b902617ad2371 (diff)
Revert "Revert "Add (disabled) dense tensor binary format""
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java87
1 files changed, 87 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
new file mode 100644
index 00000000000..0a97576d5b7
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -0,0 +1,87 @@
+package com.yahoo.tensor.serialization;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.text.Utf8;
+
+import java.util.Iterator;
+
+/**
+ * Implementation of a dense binary format for a tensor on the form:
+ *
+ * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]*
+ * Cell_values = [double, double, double, ...]*
+ * where values are encoded in order of increasing indexes in each dimension, increasing
+ * indexes of later dimensions in the dimension type before earlier.
+ *
+ * @author bratseth
+ */
+@Beta
+public class DenseBinaryFormat implements BinaryFormat {
+
+ @Override
+ public void encode(GrowableByteBuffer buffer, Tensor tensor) {
+ if ( ! ( tensor instanceof IndexedTensor))
+ throw new RuntimeException("The dense format is only supported for indexed tensors");
+ encodeDimensions(buffer, (IndexedTensor)tensor);
+ encodeCells(buffer, tensor);
+ }
+
+ private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) {
+ buffer.putInt1_4Bytes(tensor.type().dimensions().size());
+ for (int i = 0; i < tensor.type().dimensions().size(); i++) {
+ buffer.putUtf8String(tensor.type().dimensions().get(i).name());
+ buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i));
+ }
+ }
+
+ private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
+ Iterator<Double> i = tensor.valueIterator();
+ while (i.hasNext())
+ buffer.putDouble(i.next());
+ }
+
+ @Override
+ public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
+ DimensionSizes sizes = decodeDimensionSizes(type, buffer);
+ Tensor.Builder builder = Tensor.Builder.of(type, sizes);
+ decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder);
+ return builder.build();
+ }
+
+ private DimensionSizes decodeDimensionSizes(TensorType type, GrowableByteBuffer buffer) {
+ int dimensionCount = buffer.getInt1_4Bytes();
+ if (type.dimensions().size() != dimensionCount)
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
+ " dimensions but type is " + type);
+
+ DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount);
+ for (int i = 0; i < dimensionCount; i++) {
+ TensorType.Dimension expectedDimension = type.dimensions().get(i);
+
+ String encodedName = buffer.getUtf8String();
+ int encodedSize = buffer.getInt1_4Bytes();
+
+ if ( ! expectedDimension.name().equals(encodedName))
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
+ "' as dimension " + i + " but type is " + type);
+
+ if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize)
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize +
+ " in " + expectedDimension + " in type " + type);
+
+ builder.set(i, encodedSize);
+ }
+ return builder.build();
+ }
+
+ private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ for (int i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.getDouble());
+ }
+
+}