summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
committerLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
commit049e9a325c8142958909d0464da12a56e5a8f638 (patch)
tree31d857ec4a5ad3415464e480ae473c39224623b2 /vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
parentbccd68f8f9a7eb0830d136f8b034ae4f40cc819c (diff)
Add bfloat16 and int8 tensor cell types in Java
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java29
1 files changed, 29 insertions, 0 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 5bd1bbdba37..b47c0873535 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -50,6 +50,35 @@ public class TensorTestCase {
assertEquals(Tensor.from("tensor<float>(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class);
assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[1])")).cell(5.0, 0).build().getClass(),
IndexedFloatTensor.class);
+
+ assertEquals(Tensor.from("tensor<bfloat16>(x[1]):[5]").getClass(), IndexedFloatTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedFloatTensor.class);
+
+ assertEquals(Tensor.from("tensor<int8>(x[1]):[5]").getClass(), IndexedFloatTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedFloatTensor.class);
+ }
+
+ private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) {
+ Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }");
+ Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }");
+ assertEquals(valueType, t1.multiply(t2).type().valueType());
+ assertEquals(valueType, t2.multiply(t1).type().valueType());
+ }
+
+ @Test
+ public void testValueTypeResolving() {
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8");
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "float");
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16");
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8");
+ assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16");
+ assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8");
+ assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8");
}
@Test