summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
committerLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
commit049e9a325c8142958909d0464da12a56e5a8f638 (patch)
tree31d857ec4a5ad3415464e480ae473c39224623b2 /vespajlib/src/test/java/com
parentbccd68f8f9a7eb0830d136f8b034ae4f40cc819c (diff)
Add bfloat16 and int8 tensor cell types in Java
Diffstat (limited to 'vespajlib/src/test/java/com')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java29
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java39
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java15
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java62
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java42
6 files changed, 183 insertions, 8 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 5bd1bbdba37..b47c0873535 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -50,6 +50,35 @@ public class TensorTestCase {
assertEquals(Tensor.from("tensor<float>(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class);
assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[1])")).cell(5.0, 0).build().getClass(),
IndexedFloatTensor.class);
+
+ assertEquals(Tensor.from("tensor<bfloat16>(x[1]):[5]").getClass(), IndexedFloatTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedFloatTensor.class);
+
+ assertEquals(Tensor.from("tensor<int8>(x[1]):[5]").getClass(), IndexedFloatTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedFloatTensor.class);
+ }
+
+ private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) {
+ Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }");
+ Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }");
+ assertEquals(valueType, t1.multiply(t2).type().valueType());
+ assertEquals(valueType, t2.multiply(t1).type().valueType());
+ }
+
+ @Test
+ public void testValueTypeResolving() {
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8");
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "float");
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16");
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8");
+ assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16");
+ assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8");
+ assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8");
}
@Test
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index a547f941d8e..caa125dfef7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -96,8 +96,12 @@ public class TensorTypeTestCase {
assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
+ assertValueType(TensorType.Value.BFLOAT16, "tensor<bfloat16>(x[])");
+ assertValueType(TensorType.Value.INT8, "tensor<int8>(x[])");
assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString());
assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString());
+ assertEquals("tensor<bfloat16>(x[])", TensorType.fromSpec("tensor<bfloat16>(x[])").toString());
+ assertEquals("tensor<int8>(x[])", TensorType.fromSpec("tensor<int8>(x[])").toString());
}
private static void assertTensorType(String typeSpec) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 5d1bc7b0c3f..3c79b0c769c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -41,7 +41,7 @@ public class DenseBinaryFormatTestCase {
}
@Test
- public void requireThatDefaultSerializationFormatDoNotChange() {
+ public void requireThatDefaultSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[]{2, // binary format type
2, // dimension count
2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
@@ -54,7 +54,7 @@ public class DenseBinaryFormatTestCase {
}
@Test
- public void requireThatFloatSerializationFormatDoNotChange() {
+ public void requireThatFloatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[]{6, // binary format type
1, // float type
2, // dimension count
@@ -68,9 +68,44 @@ public class DenseBinaryFormatTestCase {
}
@Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[]{6, // binary format type
+ 2, // bfloat16 type
+ 2, // dimension count
+ 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
+ 1, (byte) 'z', 1, // dimension z with size
+ 64, 0, // value 1
+ 64, 64, // value 2
+ };
+ Tensor tensor = Tensor.from("tensor<bfloat16>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[]{6, // binary format type
+ 3, // int8 type
+ 2, // dimension count
+ 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
+ 1, (byte) 'z', 1, // dimension z with size
+ 2, // value 1
+ 3, // value 2
+ };
+ Tensor tensor = Tensor.from("tensor<int8>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
public void testSerializationOfDifferentValueTypes() {
+ assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<int8>(x[],y[]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}");
+ assertSerialization("tensor<double>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]");
+ assertSerialization("tensor<float>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]");
+ assertSerialization("tensor<bfloat16>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]");
+ assertSerialization("tensor<int8>(x[2],y[2]):[2, 3, 4, 5]");
}
private void assertSerialization(String tensorString) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 81de8a9db4c..3ca20661587 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -134,4 +134,19 @@ public class JsonFormatTestCase {
}
}
+ private void assertEncodeDecode(Tensor tensor) {
+ Tensor decoded = JsonFormat.decode(tensor.type(), JsonFormat.encodeWithType(tensor));
+ assertEquals(tensor, decoded);
+ assertEquals(tensor.type(), decoded.type());
+ }
+
+ @Test
+ public void testTensorCellTypes() {
+ assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]"));
+ assertEncodeDecode(Tensor.from("tensor<double>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]"));
+ assertEncodeDecode(Tensor.from("tensor<float>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]"));
+ assertEncodeDecode(Tensor.from("tensor<bfloat16>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]"));
+ assertEncodeDecode(Tensor.from("tensor<int8>(x[2],y[2]):[2,3,5,8]"));
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
index 69ef4922d8d..e9f8c81f21b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
@@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
+import java.util.Arrays;
import java.util.Optional;
import static org.junit.Assert.assertEquals;
@@ -78,9 +79,70 @@ public class MixedBinaryFormatTestCase {
}
@Test
+ public void requireThatDefaultSerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {3, // binary format type
+ 1, // number of sparse dimensions
+ 2, (byte)'x', (byte)'y', // name of sparse dimension
+ 1, // number of dense dimensions
+ 1, (byte)'z', 1, // name and size of dense dimension
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1
+ Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatFloatSerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {7, // binary format type
+ 1, // float type
+ 1, // number of sparse dimensions
+ 2, (byte)'x', (byte)'y', // name of sparse dimension
+ 1, // number of dense dimensions
+ 1, (byte)'z', 1, // name and size of dense dimension
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 64, 0, 0, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 64, 64, 0, 0}; // cell 1
+ Tensor tensor = Tensor.from("tensor<float>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {7, // binary format type
+ 2, // bfloat16 type
+ 1, // number of sparse dimensions
+ 2, (byte)'x', (byte)'y', // name of sparse dimension
+ 1, // number of dense dimensions
+ 1, (byte)'z', 1, // name and size of dense dimension
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 64, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 64, 64}; // cell 1
+ Tensor tensor = Tensor.from("tensor<bfloat16>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {7, // binary format type
+ 3, // int8 type
+ 1, // number of sparse dimensions
+ 2, (byte)'x', (byte)'y', // name of sparse dimension
+ 1, // number of dense dimensions
+ 1, (byte)'z', 1, // name and size of dense dimension
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 2, // cell 0
+ 2, (byte)'c', (byte)'d', 3}; // cell 1
+ Tensor tensor = Tensor.from("tensor<int8>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
public void testSerializationOfDifferentValueTypes() {
assertSerialization("tensor<double>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<int8>(x{},y[2]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}");
}
private void assertSerialization(String tensorString) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index 50b71024ddf..2a622b73513 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -55,19 +55,19 @@ public class SparseBinaryFormatTestCase {
}
@Test
- public void requireThatSerializationFormatDoNotChange() {
+ public void requireThatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[] {1, // binary format type
2, // num dimensions
2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
2, // num cells,
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
}
@Test
- public void requireThatFloatSerializationFormatDoNotChange() {
+ public void requireThatFloatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[] {
5, // binary format type
1, // float type
@@ -76,14 +76,44 @@ public class SparseBinaryFormatTestCase {
2, // num cells,
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Tensor tensor = Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {
+ 5, // binary format type
+ 2, // bfloat16 type
+ 2, // num dimensions
+ 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64}; // cell 1
+ Tensor tensor = Tensor.from("tensor<bfloat16>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {
+ 5, // binary format type
+ 3, // int8 type
+ 2, // num dimensions
+ 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 1, (byte)'e', 2, // cell 0
+ 2, (byte)'c', (byte)'d', 1, (byte)'e', 3}; // cell 1
+ Tensor tensor = Tensor.from("tensor<int8>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
}
@Test
public void testSerializationOfDifferentValueTypes() {
assertSerialization("tensor<double>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<int8>(x{},y{}):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}");
}
private void assertSerialization(String tensorString) {