diff options
author | Lester Solbakken <lesters@oath.com> | 2021-04-08 11:24:52 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-04-08 11:24:52 +0200 |
commit | 049e9a325c8142958909d0464da12a56e5a8f638 (patch) | |
tree | 31d857ec4a5ad3415464e480ae473c39224623b2 /vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java | |
parent | bccd68f8f9a7eb0830d136f8b034ae4f40cc819c (diff) |
Add bfloat16 and int8 tensor cell types in Java
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 5bd1bbdba37..b47c0873535 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -50,6 +50,35 @@ public class TensorTestCase { assertEquals(Tensor.from("tensor<float>(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class); assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[1])")).cell(5.0, 0).build().getClass(), IndexedFloatTensor.class); + + assertEquals(Tensor.from("tensor<bfloat16>(x[1]):[5]").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + + assertEquals(Tensor.from("tensor<int8>(x[1]):[5]").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + } + + private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) { + Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }"); + Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }"); + assertEquals(valueType, t1.multiply(t2).type().valueType()); + assertEquals(valueType, t2.multiply(t1).type().valueType()); + } + + @Test + public void testValueTypeResolving() { + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "float"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16"); + assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8"); } @Test |