diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-02 12:17:39 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-02 12:17:39 +0200 |
commit | 6eb80166172e10255841fd3d3cf70bed09d3d8c1 (patch) | |
tree | 0fe5b3ebb646460c70b9222ba6b1505a81d3619f /vespajlib/src/test | |
parent | da41894e5b4f7525ee59d9c69838bdc21735d0f2 (diff) |
Parse tensor value type
Diffstat (limited to 'vespajlib/src/test')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java | 22 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java | 18 |
2 files changed, 24 insertions, 16 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index f7a0a3cdb7d..d3bb702175a 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -58,10 +58,13 @@ public class TensorTypeTestCase { @Test public void requireThatIllegalSyntaxInSpecThrowsException() { - assertIllegalTensorType("foo(x[10])", "Tensor type spec must start with 'tensor(' and end with ')', but was 'foo(x[10])'"); - assertIllegalTensorType("tensor(x_@[10])", "Failed parsing element 'x_@[10]' in type spec 'tensor(x_@[10])'"); - assertIllegalTensorType("tensor(x[10a])", "Failed parsing element 'x[10a]' in type spec 'tensor(x[10a])'"); - assertIllegalTensorType("tensor(x{10})", "Failed parsing element 'x{10}' in type spec 'tensor(x{10})'"); + assertIllegalTensorType("foo(x[10])", "but was 'foo(x[10])'."); + assertIllegalTensorType("tensor(x_@[10])", "Dimension 'x_@[10]' is on the wrong format"); + assertIllegalTensorType("tensor(x[10a])", "Dimension 'x[10a]' is on the wrong format"); + assertIllegalTensorType("tensor(x{10})", "Dimension 'x{10}' is on the wrong format"); + assertIllegalTensorType("tensor<(x{})", " Value type spec must be enclosed in <>"); + assertIllegalTensorType("tensor<>(x{})", "Value type must be"); + assertIllegalTensorType("tensor<notavalue>(x{})", "Value type must be"); } @Test @@ -88,6 +91,13 @@ public class TensorTypeTestCase { assertIsConvertibleTo("tensor(x{},y[10])", "tensor(x{},y[])"); } + @Test + public void testValueType() { + assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); + assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])"); + assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])"); + } + private static void assertTensorType(String typeSpec) { assertTensorType(typeSpec, typeSpec); } @@ -121,4 +131,8 @@ public class TensorTypeTestCase { assertFalse(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType))); } + private void assertValueType(TensorType.Value expectedValueType, String tensorTypeSpec) { + assertEquals(expectedValueType, TensorType.fromSpec(tensorTypeSpec).valueType()); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index e8b17812f32..8fa2537e1e2 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -63,27 +63,21 @@ public class DenseBinaryFormatTestCase { 64, 0, 0, 0, // value 1 64, 64, 0, 0, // value 2 }; - Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); - tensor.type().valueType(TensorType.ValueType.FLOAT); - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(tensor))); + Tensor tensor = Tensor.from("tensor<float>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test public void testSerializationOfDifferentValueTypes() { - assertSerialization(TensorType.ValueType.DOUBLE, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); - assertSerialization(TensorType.ValueType.FLOAT, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); } private void assertSerialization(String tensorString) { - assertSerialization(TensorType.ValueType.DOUBLE, Tensor.from(tensorString)); - } - private void assertSerialization(TensorType.ValueType valueType, String tensorString) { - assertSerialization(valueType, Tensor.from(tensorString)); + assertSerialization(Tensor.from(tensorString)); } - private void assertSerialization(TensorType.ValueType valueType, Tensor tensor) { - tensor.type().valueType(valueType); + private void assertSerialization(Tensor tensor) { assertSerialization(tensor, tensor.type()); } |