From 049e9a325c8142958909d0464da12a56e5a8f638 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 8 Apr 2021 11:24:52 +0200 Subject: Add bfloat16 and int8 tensor cell types in Java --- .../test/java/com/yahoo/tensor/TensorTestCase.java | 29 ++++++++++ .../java/com/yahoo/tensor/TensorTypeTestCase.java | 4 ++ .../serialization/DenseBinaryFormatTestCase.java | 39 +++++++++++++- .../tensor/serialization/JsonFormatTestCase.java | 15 ++++++ .../serialization/MixedBinaryFormatTestCase.java | 62 ++++++++++++++++++++++ .../serialization/SparseBinaryFormatTestCase.java | 42 ++++++++++++--- 6 files changed, 183 insertions(+), 8 deletions(-) (limited to 'vespajlib/src/test/java/com/yahoo/tensor') diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 5bd1bbdba37..b47c0873535 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -50,6 +50,35 @@ public class TensorTestCase { assertEquals(Tensor.from("tensor(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class); assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(5.0, 0).build().getClass(), IndexedFloatTensor.class); + + assertEquals(Tensor.from("tensor(x[1]):[5]").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + + assertEquals(Tensor.from("tensor(x[1]):[5]").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + } + + private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) { + Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }"); + Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }"); + assertEquals(valueType, t1.multiply(t2).type().valueType()); + assertEquals(valueType, t2.multiply(t1).type().valueType()); + } + + @Test + public void testValueTypeResolving() { + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "float"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16"); + assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8"); } @Test diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index a547f941d8e..caa125dfef7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -96,8 +96,12 @@ public class TensorTypeTestCase { assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); assertValueType(TensorType.Value.FLOAT, "tensor(x[])"); + assertValueType(TensorType.Value.BFLOAT16, "tensor(x[])"); + assertValueType(TensorType.Value.INT8, "tensor(x[])"); assertEquals("tensor(x[])", TensorType.fromSpec("tensor(x[])").toString()); assertEquals("tensor(x[])", TensorType.fromSpec("tensor(x[])").toString()); + assertEquals("tensor(x[])", TensorType.fromSpec("tensor(x[])").toString()); + assertEquals("tensor(x[])", TensorType.fromSpec("tensor(x[])").toString()); } private static void assertTensorType(String typeSpec) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 5d1bc7b0c3f..3c79b0c769c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -41,7 +41,7 @@ public class DenseBinaryFormatTestCase { } @Test - public void requireThatDefaultSerializationFormatDoNotChange() { + public void requireThatDefaultSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[]{2, // binary format type 2, // dimension count 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size @@ -54,7 +54,7 @@ public class DenseBinaryFormatTestCase { } @Test - public void requireThatFloatSerializationFormatDoNotChange() { + public void requireThatFloatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[]{6, // binary format type 1, // float type 2, // dimension count @@ -67,10 +67,45 @@ public class DenseBinaryFormatTestCase { assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } + @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[]{6, // binary format type + 2, // bfloat16 type + 2, // dimension count + 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size + 1, (byte) 'z', 1, // dimension z with size + 64, 0, // value 1 + 64, 64, // value 2 + }; + Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[]{6, // binary format type + 3, // int8 type + 2, // dimension count + 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size + 1, (byte) 'z', 1, // dimension z with size + 2, // value 1 + 3, // value 2 + }; + Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + @Test public void testSerializationOfDifferentValueTypes() { + assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); + assertSerialization("tensor(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); + assertSerialization("tensor(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); + assertSerialization("tensor(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); + assertSerialization("tensor(x[2],y[2]):[2, 3, 4, 5]"); } private void assertSerialization(String tensorString) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 81de8a9db4c..3ca20661587 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -134,4 +134,19 @@ public class JsonFormatTestCase { } } + private void assertEncodeDecode(Tensor tensor) { + Tensor decoded = JsonFormat.decode(tensor.type(), JsonFormat.encodeWithType(tensor)); + assertEquals(tensor, decoded); + assertEquals(tensor.type(), decoded.type()); + } + + @Test + public void testTensorCellTypes() { + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2,3,5,8]")); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java index 69ef4922d8d..e9f8c81f21b 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; +import java.util.Arrays; import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -77,10 +78,71 @@ public class MixedBinaryFormatTestCase { assertSerialization(tensor); } + @Test + public void requireThatDefaultSerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {3, // binary format type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0 + 2, (byte)'c', (byte)'d', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatFloatSerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {7, // binary format type + 1, // float type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 64, 0, 0, 0, // cell 0 + 2, (byte)'c', (byte)'d', 64, 64, 0, 0}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {7, // binary format type + 2, // bfloat16 type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 64, 0, // cell 0 + 2, (byte)'c', (byte)'d', 64, 64}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {7, // binary format type + 3, // int8 type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 2, // cell 0 + 2, (byte)'c', (byte)'d', 3}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + @Test public void testSerializationOfDifferentValueTypes() { assertSerialization("tensor(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x{},y[2]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); } private void assertSerialization(String tensorString) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 50b71024ddf..2a622b73513 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -55,19 +55,19 @@ public class SparseBinaryFormatTestCase { } @Test - public void requireThatSerializationFormatDoNotChange() { + public void requireThatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[] {1, // binary format type 2, // num dimensions 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions 2, // num cells, 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1 - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test - public void requireThatFloatSerializationFormatDoNotChange() { + public void requireThatFloatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[] { 5, // binary format type 1, // float type @@ -76,14 +76,44 @@ public class SparseBinaryFormatTestCase { 2, // num cells, 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1 - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] { + 5, // binary format type + 2, // bfloat16 type + 2, // num dimensions + 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions + 2, // num cells, + 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, // cell 0 + 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] { + 5, // binary format type + 3, // int8 type + 2, // num dimensions + 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions + 2, // num cells, + 2, (byte)'a', (byte)'b', 1, (byte)'e', 2, // cell 0 + 2, (byte)'c', (byte)'d', 1, (byte)'e', 3}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test public void testSerializationOfDifferentValueTypes() { assertSerialization("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x{},y{}):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); } private void assertSerialization(String tensorString) { -- cgit v1.2.3