aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-02 12:17:39 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-02 12:17:39 +0200
commit6eb80166172e10255841fd3d3cf70bed09d3d8c1 (patch)
tree0fe5b3ebb646460c70b9222ba6b1505a81d3619f /vespajlib/src/test
parentda41894e5b4f7525ee59d9c69838bdc21735d0f2 (diff)
Parse tensor value type
Diffstat (limited to 'vespajlib/src/test')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java18
2 files changed, 24 insertions, 16 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index f7a0a3cdb7d..d3bb702175a 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -58,10 +58,13 @@ public class TensorTypeTestCase {
@Test
public void requireThatIllegalSyntaxInSpecThrowsException() {
- assertIllegalTensorType("foo(x[10])", "Tensor type spec must start with 'tensor(' and end with ')', but was 'foo(x[10])'");
- assertIllegalTensorType("tensor(x_@[10])", "Failed parsing element 'x_@[10]' in type spec 'tensor(x_@[10])'");
- assertIllegalTensorType("tensor(x[10a])", "Failed parsing element 'x[10a]' in type spec 'tensor(x[10a])'");
- assertIllegalTensorType("tensor(x{10})", "Failed parsing element 'x{10}' in type spec 'tensor(x{10})'");
+ assertIllegalTensorType("foo(x[10])", "but was 'foo(x[10])'.");
+ assertIllegalTensorType("tensor(x_@[10])", "Dimension 'x_@[10]' is on the wrong format");
+ assertIllegalTensorType("tensor(x[10a])", "Dimension 'x[10a]' is on the wrong format");
+ assertIllegalTensorType("tensor(x{10})", "Dimension 'x{10}' is on the wrong format");
+ assertIllegalTensorType("tensor<(x{})", " Value type spec must be enclosed in <>");
+ assertIllegalTensorType("tensor<>(x{})", "Value type must be");
+ assertIllegalTensorType("tensor<notavalue>(x{})", "Value type must be");
}
@Test
@@ -88,6 +91,13 @@ public class TensorTypeTestCase {
assertIsConvertibleTo("tensor(x{},y[10])", "tensor(x{},y[])");
}
+ @Test
+ public void testValueType() {
+ assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
+ assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
+ assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
+ }
+
private static void assertTensorType(String typeSpec) {
assertTensorType(typeSpec, typeSpec);
}
@@ -121,4 +131,8 @@ public class TensorTypeTestCase {
assertFalse(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType)));
}
+ private void assertValueType(TensorType.Value expectedValueType, String tensorTypeSpec) {
+ assertEquals(expectedValueType, TensorType.fromSpec(tensorTypeSpec).valueType());
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index e8b17812f32..8fa2537e1e2 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -63,27 +63,21 @@ public class DenseBinaryFormatTestCase {
64, 0, 0, 0, // value 1
64, 64, 0, 0, // value 2
};
- Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
- tensor.type().valueType(TensorType.ValueType.FLOAT);
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ Tensor tensor = Tensor.from("tensor<float>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
}
@Test
public void testSerializationOfDifferentValueTypes() {
- assertSerialization(TensorType.ValueType.DOUBLE, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
- assertSerialization(TensorType.ValueType.FLOAT, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
}
private void assertSerialization(String tensorString) {
- assertSerialization(TensorType.ValueType.DOUBLE, Tensor.from(tensorString));
- }
- private void assertSerialization(TensorType.ValueType valueType, String tensorString) {
- assertSerialization(valueType, Tensor.from(tensorString));
+ assertSerialization(Tensor.from(tensorString));
}
- private void assertSerialization(TensorType.ValueType valueType, Tensor tensor) {
- tensor.type().valueType(valueType);
+ private void assertSerialization(Tensor tensor) {
assertSerialization(tensor, tensor.type());
}