summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java42
1 files changed, 36 insertions, 6 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index 50b71024ddf..2a622b73513 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -55,19 +55,19 @@ public class SparseBinaryFormatTestCase {
}
@Test
- public void requireThatSerializationFormatDoNotChange() {
+ public void requireThatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[] {1, // binary format type
2, // num dimensions
2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
2, // num cells,
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
}
@Test
- public void requireThatFloatSerializationFormatDoNotChange() {
+ public void requireThatFloatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[] {
5, // binary format type
1, // float type
@@ -76,14 +76,44 @@ public class SparseBinaryFormatTestCase {
2, // num cells,
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Tensor tensor = Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {
+ 5, // binary format type
+ 2, // bfloat16 type
+ 2, // num dimensions
+ 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64}; // cell 1
+ Tensor tensor = Tensor.from("tensor<bfloat16>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {
+ 5, // binary format type
+ 3, // int8 type
+ 2, // num dimensions
+ 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 1, (byte)'e', 2, // cell 0
+ 2, (byte)'c', (byte)'d', 1, (byte)'e', 3}; // cell 1
+ Tensor tensor = Tensor.from("tensor<int8>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
}
@Test
public void testSerializationOfDifferentValueTypes() {
assertSerialization("tensor<double>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<int8>(x{},y{}):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}");
}
private void assertSerialization(String tensorString) {