diff options
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java | 42 |
1 files changed, 36 insertions, 6 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 50b71024ddf..2a622b73513 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -55,19 +55,19 @@ public class SparseBinaryFormatTestCase { } @Test - public void requireThatSerializationFormatDoNotChange() { + public void requireThatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[] {1, // binary format type 2, // num dimensions 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions 2, // num cells, 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1 - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test - public void requireThatFloatSerializationFormatDoNotChange() { + public void requireThatFloatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[] { 5, // binary format type 1, // float type @@ -76,14 +76,44 @@ public class SparseBinaryFormatTestCase { 2, // num cells, 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1 - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Tensor tensor = Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] { + 5, // binary format type + 2, // bfloat16 type + 2, // num dimensions + 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions + 2, // num cells, + 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, // cell 0 + 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64}; // cell 1 + Tensor tensor = Tensor.from("tensor<bfloat16>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] { + 5, // binary format type + 3, // int8 type + 2, // num dimensions + 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions + 2, // num cells, + 2, (byte)'a', (byte)'b', 1, (byte)'e', 2, // cell 0 + 2, (byte)'c', (byte)'d', 1, (byte)'e', 3}; // cell 1 + Tensor tensor = Tensor.from("tensor<int8>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test public void testSerializationOfDifferentValueTypes() { assertSerialization("tensor<double>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor<float>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<bfloat16>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<int8>(x{},y{}):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); } private void assertSerialization(String tensorString) { |