diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-09 18:42:13 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-09 18:42:13 +0100 |
commit | 1be23c3c2d98d19a405a7390681a34832b6d6f5f (patch) | |
tree | 25175193e078479f8cf294858a4345e14aedb89e /vespajlib | |
parent | b99468d847b444a3ad7f4aeba0f2eac1906c74cb (diff) |
Add (disabled) dense tensor binary format
Diffstat (limited to 'vespajlib')
9 files changed, 251 insertions, 82 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index daa85cc51e4..7570a357452 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -29,6 +29,14 @@ public final class DimensionSizes { /** Returns the number of dimensions this provides the size of */ public int dimensions() { return sizes.length; } + /** Returns the product of the sizes of this */ + public int totalSize() { + int productSize = 1; + for (int dimensionSize : sizes ) + productSize *= dimensionSize; + return productSize; + } + @Override public boolean equals(Object o) { if (o == this) return true; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 9315922f57a..4654f53647f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -217,13 +217,6 @@ public class IndexedTensor implements Tensor { public abstract Builder cell(double value, int ... indexes); - protected double[] arrayFor(DimensionSizes sizes) { - int productSize = 1; - for (int i = 0; i < sizes.dimensions(); i++ ) - productSize *= sizes.size(i); - return new double[productSize]; - } - @Override public TensorType type() { return type; } @@ -233,7 +226,7 @@ public class IndexedTensor implements Tensor { } /** A bound builder can create the double array directly */ - private static class BoundBuilder extends Builder { + public static class BoundBuilder extends Builder { private DimensionSizes sizes; private double[] values; @@ -242,7 +235,7 @@ public class IndexedTensor implements Tensor { this(type, dimensionSizesOf(type)); } - public static DimensionSizes dimensionSizesOf(TensorType type) { + static DimensionSizes dimensionSizesOf(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < type.dimensions().size(); i++) b.set(i, type.dimensions().get(i).size().get()); @@ -254,7 +247,7 @@ public class IndexedTensor implements Tensor { if ( sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; - values = arrayFor(sizes); + values = new double[sizes.totalSize()]; Arrays.fill(values, Double.NaN); } @@ -290,9 +283,6 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(Cell cell, double value) { - // TODO: Use internal index if applicable - // values[internalIndex] = value; - // return this; int directIndex = cell.getDirectIndex(); if (directIndex >= 0) // optimization values[directIndex] = value; @@ -301,6 +291,15 @@ public class IndexedTensor implements Tensor { return this; } + /** + * Set a cell value by the index in the internal layout of this cell. + * This requires knowledge of the internal layout of cells in this implementation, and should therefore + * probably not be used (but when it can be used it is fast). + */ + public void cellByDirectIndex(int index, double value) { + values[index] = value; + } + } /** @@ -324,7 +323,7 @@ public class IndexedTensor implements Tensor { return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); - double[] values = arrayFor(dimensionSizes); + double[] values = new double[dimensionSizes.totalSize()]; fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java index f3adf63739a..9b0ccdcb6c8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization; import com.google.common.annotations.Beta; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * Representation of a specific binary format with functions for serializing a Tensor object into @@ -21,7 +22,10 @@ interface BinaryFormat { /** * Deserialize the given binary data into a Tensor object. + * + * @param type the expected abstract type of the tensor to serialize + * @param buffer the buffer containing the tensor binary data */ - Tensor decode(GrowableByteBuffer buffer); + Tensor decode(TensorType type, GrowableByteBuffer buffer); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java new file mode 100644 index 00000000000..388e63e6e34 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -0,0 +1,105 @@ +package com.yahoo.tensor.serialization; + +import com.google.common.annotations.Beta; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.tensor.DimensionSizes; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.text.Utf8; + +import java.util.Iterator; + +/** + * Implementation of a dense binary format for a tensor on the form: + * + * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]* + * Cell_values = [double, double, double, ...]* + * where values are encoded in order of increasing indexes in each dimension, increasing + * indexes of later dimensions in the dimension type before earlier. + * + * @author bratseth + */ +@Beta +public class DenseBinaryFormat implements BinaryFormat { + + @Override + public void encode(GrowableByteBuffer buffer, Tensor tensor) { + if ( ! ( tensor instanceof IndexedTensor)) + throw new RuntimeException("The dense format is only supported for indexed tensors"); + encodeDimensions(buffer, (IndexedTensor)tensor); + encodeCells(buffer, tensor); + } + + private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) { + buffer.putInt1_4Bytes(tensor.type().dimensions().size()); + for (int i = 0; i < tensor.type().dimensions().size(); i++) { + encodeString(buffer, tensor.type().dimensions().get(i).name()); + buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i)); + } + } + + private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { + Iterator<Double> i = tensor.valueIterator(); + if ( ! i.hasNext()) { // no values: Encode as NaN, as 0 dimensions may also mean 1 value + buffer.putDouble(Double.NaN); + } + else { + while (i.hasNext()) + buffer.putDouble(i.next()); + } + } + + private void encodeString(GrowableByteBuffer buffer, String value) { + byte[] stringBytes = Utf8.toBytes(value); + buffer.putInt1_4Bytes(stringBytes.length); + buffer.put(stringBytes); + } + + @Override + public Tensor decode(TensorType type, GrowableByteBuffer buffer) { + DimensionSizes sizes = decodeDimensionSizes(type, buffer); + Tensor.Builder builder = Tensor.Builder.of(type, sizes); + decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder); + return builder.build(); + } + + private DimensionSizes decodeDimensionSizes(TensorType type, GrowableByteBuffer buffer) { + int dimensionCount = buffer.getInt1_4Bytes(); + if (type.dimensions().size() != dimensionCount) + throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + + " dimensions but type is " + type); + + DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount); + for (int i = 0; i < dimensionCount; i++) { + TensorType.Dimension expectedDimension = type.dimensions().get(i); + + String encodedName = decodeString(buffer); + int encodedSize = buffer.getInt1_4Bytes(); + + if ( ! expectedDimension.name().equals(encodedName)) + throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName + + "' as dimension " + i + " but type is " + type); + + if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize) + throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize + + " in " + expectedDimension); + + builder.set(i, encodedSize); + } + return builder.build(); + } + + private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { + for (int i = 0; i < sizes.totalSize(); i++) + builder.cellByDirectIndex(i, buffer.getDouble()); + } + + private String decodeString(GrowableByteBuffer buffer) { + int stringLength = buffer.getInt1_4Bytes(); + byte[] stringBytes = new byte[stringLength]; + buffer.get(stringBytes); + return Utf8.toString(stringBytes); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 27a009b5e7e..612df272b6f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -31,14 +31,14 @@ class SparseBinaryFormat implements BinaryFormat { encodeCells(buffer, tensor); } - private static void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) { + private void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) { buffer.putInt1_4Bytes(sortedDimensions.size()); for (TensorType.Dimension dimension : sortedDimensions) { encodeString(buffer, dimension.name()); } } - private static void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { + private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { buffer.putInt1_4Bytes(tensor.size()); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); @@ -47,35 +47,53 @@ class SparseBinaryFormat implements BinaryFormat { } } - private static void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) { + private void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) { for (int i = 0; i < address.size(); i++) encodeString(buffer, address.label(i)); } - private static void encodeString(GrowableByteBuffer buffer, String value) { + private void encodeString(GrowableByteBuffer buffer, String value) { byte[] stringBytes = Utf8.toBytes(value); buffer.putInt1_4Bytes(stringBytes.length); buffer.put(stringBytes); } @Override - public Tensor decode(GrowableByteBuffer buffer) { - TensorType type = decodeDimensions(buffer); + public Tensor decode(TensorType type, GrowableByteBuffer buffer) { + if (type == null) // TODO (January 2017): Remove this when types are available + type = decodeDimensionsToType(buffer); + else + consumeAndValidateDimensions(type, buffer); Tensor.Builder builder = Tensor.Builder.of(type); decodeCells(buffer, builder, type); return builder.build(); } - private static TensorType decodeDimensions(GrowableByteBuffer buffer) { + private TensorType decodeDimensionsToType(GrowableByteBuffer buffer) { TensorType.Builder builder = new TensorType.Builder(); int numDimensions = buffer.getInt1_4Bytes(); for (int i = 0; i < numDimensions; ++i) { - builder.mapped(decodeString(buffer)); // TODO: Support indexed + builder.mapped(decodeString(buffer)); } return builder.build(); } - private static void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { + private void consumeAndValidateDimensions(TensorType type, GrowableByteBuffer buffer) { + int dimensionCount = buffer.getInt1_4Bytes(); + if (type.dimensions().size() != dimensionCount) + throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + + " dimensions but type is " + type); + + for (int i = 0; i < dimensionCount; ++i) { + TensorType.Dimension expectedDimension = type.dimensions().get(i); + String encodedName = decodeString(buffer); + if ( ! expectedDimension.name().equals(encodedName)) + throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName + + "' as dimension " + i + " but type is " + type); + } + } + + private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { int numCells = buffer.getInt1_4Bytes(); for (int i = 0; i < numCells; ++i) { Tensor.Builder.CellBuilder cellBuilder = builder.cell(); @@ -84,7 +102,7 @@ class SparseBinaryFormat implements BinaryFormat { } } - private static void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) { + private void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) { for (TensorType.Dimension dimension : type.dimensions()) { String label = decodeString(buffer); if ( ! label.isEmpty()) { @@ -93,7 +111,7 @@ class SparseBinaryFormat implements BinaryFormat { } } - private static String decodeString(GrowableByteBuffer buffer) { + private String decodeString(GrowableByteBuffer buffer) { int stringLength = buffer.getInt1_4Bytes(); byte[] stringBytes = new byte[stringLength]; buffer.get(stringBytes); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 5a45f20b6d8..65216aa2fcd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -3,7 +3,9 @@ package com.yahoo.tensor.serialization; import com.google.common.annotations.Beta; import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * Class used by clients for serializing a Tensor object into binary format or @@ -18,25 +20,31 @@ import com.yahoo.tensor.Tensor; public class TypedBinaryFormat { private static final int SPARSE_BINARY_FORMAT_TYPE = 1; + private static final int DENSE_BINARY_FORMAT_TYPE = 2; public static byte[] encode(Tensor tensor) { GrowableByteBuffer buffer = new GrowableByteBuffer(); - buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE); - new SparseBinaryFormat().encode(buffer, tensor); + if (tensor instanceof IndexedTensor && 1==2) { // TODO: Activate when we have type information everywhere + buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE); + new DenseBinaryFormat().encode(buffer, tensor); + } + else { + buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE); + new SparseBinaryFormat().encode(buffer, tensor); + } buffer.flip(); byte[] result = new byte[buffer.remaining()]; buffer.get(result); return result; } - public static Tensor decode(byte[] data) { + public static Tensor decode(TensorType type, byte[] data) { GrowableByteBuffer buffer = GrowableByteBuffer.wrap(data); int formatType = buffer.getInt1_4Bytes(); switch (formatType) { - case SPARSE_BINARY_FORMAT_TYPE: - return new SparseBinaryFormat().decode(buffer); - default: - throw new IllegalArgumentException("Binary format type " + formatType + " is not a known format"); + case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer); + case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat().decode(type, buffer); + default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown"); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index 2f060239eb1..e2baa1d5ac3 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -27,6 +27,7 @@ public class TensorFunctionBenchmark { modelVectors = modelVectors.stream().map(t -> t.multiply(unitVector("k"))).collect(Collectors.toList()); } dotProduct(queryVector, modelVectors, Math.max(iterations/10, 10)); // warmup + System.gc(); long startTime = System.currentTimeMillis(); dotProduct(queryVector, modelVectors, iterations); long totalTime = System.currentTimeMillis() - startTime; @@ -106,51 +107,41 @@ public class TensorFunctionBenchmark { // ---------------- Mapped with extra space (sidesteps current special-case optimizations): // 410 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(20, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time); // 770 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(20, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); // ---------------- Mapped: // 2.6 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time); // 6.8 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations): // 30 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time); // 27 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time); // ---------------- Indexed unbound: // 0.14 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); // 0.14 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed bound: // 0.14 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time); // 0.14 ms - System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java new file mode 100644 index 00000000000..697eb2a7329 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -0,0 +1,56 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.serialization; + +import com.google.common.collect.Sets; +import com.yahoo.tensor.Tensor; +import org.junit.Ignore; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Set; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for the dense binary format. + * + * @author bratseth + */ +public class DenseBinaryFormatTestCase { + + @Test + public void testSerialization() { + assertSerialization("{}"); + assertSerialization("{-5.37}"); + assertSerialization("tensor(x[]):{{x:0}:2.0}"); + assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0}"); + assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}"); + } + + @Test + @Ignore // TODO: Activate when encoding in this format is activated + public void requireThatSerializationFormatDoNotChange() { + byte[] encodedTensor = new byte[]{2, // binary format type + 2, // dimension count + 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size + 1, (byte) 'z', 1, // dimension z with size + 64, 0, 0, 0, 0, 0, 0, 0, // value 1 + 64, 8, 0, 0, 0, 0, 0, 0 // value 2 + }; + assertEquals(Arrays.toString(encodedTensor), + Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}")))); + } + + private void assertSerialization(String tensorString) { + assertSerialization(Tensor.from(tensorString)); + } + + private void assertSerialization(Tensor tensor) { + byte[] encodedTensor = TypedBinaryFormat.encode(tensor); + Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor); + assertEquals(tensor, decodedTensor); + } + +} + diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index ad908101329..b314fe06f08 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -13,52 +13,22 @@ import static org.junit.Assert.assertEquals; /** * Tests for the sparse binary format. * - * TODO: When new formats are added we should refactor this test to test all formats - * with the same set of tensor inputs (if feasible). - * * @author geirst */ public class SparseBinaryFormatTestCase { - private static void assertSerialization(String tensorString) { - assertSerialization(Tensor.from(tensorString)); - } - - private static void assertSerialization(String tensorString, Set<String> dimensions) { - Tensor tensor = Tensor.from(tensorString); - assertEquals(dimensions, tensor.type().dimensionNames()); - assertSerialization(tensor); - } - - private static void assertSerialization(Tensor tensor) { - byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(encodedTensor); - assertEquals(tensor, decodedTensor); - } - @Test - public void testSerializationOfTensorsWithDenseTensorAddresses() { - assertSerialization("{}"); - assertSerialization("{{x:0}:2.0}"); - assertSerialization("{{x:0}:2.0,{x:1}:3.0}"); - assertSerialization("{{x:0,y:0}:2.0}"); - assertSerialization("{{x:0,y:0}:2.0,{x:0,y:1}:3.0}"); - assertSerialization("{{y:0,x:0}:2.0}"); - assertSerialization("{{y:0,x:0}:2.0,{y:1,x:0}:3.0}"); - assertSerialization("{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}"); + public void testSerialization() { + assertSerialization("tensor(x{}):{{x:0}:2.0}"); + assertSerialization("tensor(dimX{},dimY{}):{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}"); + assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}"); + assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0,{x:1,y:4}:3.0}"); + assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0}"); + assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}"); } @Test - public void testSerializationOfTensorsWithSparseTensorAddresses() { - assertSerialization("{{x:0}:2.0, {x:1}:3.0}", Sets.newHashSet("x")); - assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}", Sets.newHashSet("x", "y")); - assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0,{x:1,y:4}:3.0}", Sets.newHashSet("x", "y")); - assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0}", Sets.newHashSet("x", "y", "z")); - assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}", Sets.newHashSet("x", "y", "z")); - } - - @Test - public void requireThatCompactSerializationFormatDoNotChange() { + public void requireThatSerializationFormatDoNotChange() { byte[] encodedTensor = new byte[] {1, // binary format type 2, // num dimensions 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions @@ -66,7 +36,17 @@ public class SparseBinaryFormatTestCase { 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1 assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + } + + private void assertSerialization(String tensorString) { + assertSerialization(Tensor.from(tensorString)); + } + + private void assertSerialization(Tensor tensor) { + byte[] encodedTensor = TypedBinaryFormat.encode(tensor); + Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor); + assertEquals(tensor, decodedTensor); } } |