diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-03-29 12:21:56 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-03-29 12:21:56 +0200 |
commit | e69d6e8f3d8a6504135f6d2733a3a42f6a041ed4 (patch) | |
tree | 046483fb628977f62a66cb660d4a09fcd4302e0d /vespajlib/src/test/java/com/yahoo/tensor | |
parent | 13100e8dcc72b7c879727e5d96e1fdfceb2d3bcc (diff) |
Validate query feature tensor types
- Validate tensor feature types when a tensor is set programmatically.
- Add a toShortString for messages containing tensors.
- Consistent and nicer spacing in tensor string forms.
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor')
4 files changed, 71 insertions, 21 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java index 7bb02f03735..ba814f7ad54 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java @@ -35,7 +35,7 @@ public class MappedTensorTestCase { cell().label("x", "0").value(1). cell().label("x", "1").value(2).build(); assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames()); - assertEquals("tensor(x{}):{0:1.0,1:2.0}", tensor.toString()); + assertEquals("tensor(x{}):{0:1.0, 1:2.0}", tensor.toString()); } @Test @@ -45,7 +45,7 @@ public class MappedTensorTestCase { cell().label("x", "0").label("y", "0").value(1). cell().label("x", "1").label("y", "0").value(2).build(); assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames()); - assertEquals("tensor(x{},y{}):{{x:0,y:0}:1.0,{x:1,y:0}:2.0}", tensor.toString()); + assertEquals("tensor(x{},y{}):{{x:0,y:0}:1.0, {x:1,y:0}:2.0}", tensor.toString()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java index 50f2bc5efff..a26e56c4468 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java @@ -69,7 +69,7 @@ public class MixedTensorTestCase { cell().label("x", "1").value(2). build(); assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames()); - assertEquals("tensor(x{}):{0:1.0,1:2.0}", + assertEquals("tensor(x{}):{0:1.0, 1:2.0}", tensor.toString()); } @@ -84,7 +84,7 @@ public class MixedTensorTestCase { cell().label("x", "1").label("y", "2").value(6). build(); assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames()); - assertEquals("tensor(x{},y{}):{{x:0,y:0}:1.0,{x:0,y:1}:2.0,{x:1,y:0}:4.0,{x:1,y:1}:5.0,{x:1,y:2}:6.0}", + assertEquals("tensor(x{},y{}):{{x:0,y:0}:1.0, {x:0,y:1}:2.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0, {x:1,y:2}:6.0}", tensor.toString()); } @@ -100,7 +100,7 @@ public class MixedTensorTestCase { cell().label("x", "2").label("y", 2).value(6). build(); assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames()); - assertEquals("tensor(x{},y[3]):{1:[1.0, 2.0, 0.0],2:[4.0, 5.0, 6.0]}", + assertEquals("tensor(x{},y[3]):{1:[1.0, 2.0, 0.0], 2:[4.0, 5.0, 6.0]}", tensor.toString()); } @@ -122,7 +122,9 @@ public class MixedTensorTestCase { cell().label("x", "x2").label("y", 2).label("z","z2").value(16). build(); assertEquals(Sets.newHashSet("x", "y", "z"), tensor.type().dimensionNames()); - assertEquals("tensor(x{},y[3],z{}):{{x:x1,y:0,z:z1}:1.0,{x:x1,y:0,z:z2}:2.0,{x:x1,y:1,z:z1}:3.0,{x:x1,y:1,z:z2}:4.0,{x:x1,y:2,z:z1}:5.0,{x:x1,y:2,z:z2}:6.0,{x:x2,y:0,z:z1}:11.0,{x:x2,y:0,z:z2}:12.0,{x:x2,y:1,z:z1}:13.0,{x:x2,y:1,z:z2}:14.0,{x:x2,y:2,z:z1}:15.0,{x:x2,y:2,z:z2}:16.0}", + assertEquals("tensor(x{},y[3],z{}):{{x:x1,y:0,z:z1}:1.0, {x:x1,y:0,z:z2}:2.0, {x:x1,y:1,z:z1}:3.0, " + + "{x:x1,y:1,z:z2}:4.0, {x:x1,y:2,z:z1}:5.0, {x:x1,y:2,z:z2}:6.0, {x:x2,y:0,z:z1}:11.0, " + + "{x:x2,y:0,z:z2}:12.0, {x:x2,y:1,z:z1}:13.0, {x:x2,y:1,z:z2}:14.0, {x:x2,y:2,z:z1}:15.0, {x:x2,y:2,z:z2}:16.0}", tensor.toString()); } @@ -148,7 +150,11 @@ public class MixedTensorTestCase { cell().label("i", "b").label("k","d").label("j",1).label("l",1).value(16). build(); assertEquals(Sets.newHashSet("i", "j", "k", "l"), tensor.type().dimensionNames()); - assertEquals("tensor(i{},j[2],k{},l[2]):{{i:a,j:0,k:c,l:0}:1.0,{i:a,j:0,k:c,l:1}:2.0,{i:a,j:0,k:d,l:0}:5.0,{i:a,j:0,k:d,l:1}:6.0,{i:a,j:1,k:c,l:0}:3.0,{i:a,j:1,k:c,l:1}:4.0,{i:a,j:1,k:d,l:0}:7.0,{i:a,j:1,k:d,l:1}:8.0,{i:b,j:0,k:c,l:0}:9.0,{i:b,j:0,k:c,l:1}:10.0,{i:b,j:0,k:d,l:0}:13.0,{i:b,j:0,k:d,l:1}:14.0,{i:b,j:1,k:c,l:0}:11.0,{i:b,j:1,k:c,l:1}:12.0,{i:b,j:1,k:d,l:0}:15.0,{i:b,j:1,k:d,l:1}:16.0}", + assertEquals("tensor(i{},j[2],k{},l[2]):{{i:a,j:0,k:c,l:0}:1.0, {i:a,j:0,k:c,l:1}:2.0, " + + "{i:a,j:0,k:d,l:0}:5.0, {i:a,j:0,k:d,l:1}:6.0, {i:a,j:1,k:c,l:0}:3.0, {i:a,j:1,k:c,l:1}:4.0, " + + "{i:a,j:1,k:d,l:0}:7.0, {i:a,j:1,k:d,l:1}:8.0, {i:b,j:0,k:c,l:0}:9.0, {i:b,j:0,k:c,l:1}:10.0, " + + "{i:b,j:0,k:d,l:0}:13.0, {i:b,j:0,k:d,l:1}:14.0, {i:b,j:1,k:c,l:0}:11.0, {i:b,j:1,k:c,l:1}:12.0, "+ + "{i:b,j:1,k:d,l:0}:15.0, {i:b,j:1,k:d,l:1}:16.0}", tensor.toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index fd33cf97220..2067d7a8492 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -31,17 +31,61 @@ import static org.junit.Assert.fail; public class TensorTestCase { @Test - public void testStringForm() { - assertEquals("tensor():{5.7}", Tensor.from("{5.7}").toString()); + public void testFactory() { assertTrue(Tensor.from("tensor():{5.7}") instanceof IndexedTensor); - assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString()); - assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString()); + } + + @Test + public void testToString() { + assertEquals("tensor():{5.7}", Tensor.from("{5.7}").toString()); + assertEquals("tensor(x[3]):[0.1, 0.2, 0.3]", + Tensor.from("tensor(x[3]):[0.1, 0.2, 0.3]").toString()); + assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:6.0}", + Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString()); + assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:-5.3, {d1:l1,d2:l2}:0.0}", + Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString()); + assertEquals("tensor(m{},x[3]):{k1:[0.0, 1.0, 2.0], k2:[0.0, 1.0, 2.0], k3:[0.0, 1.0, 2.0], k4:[0.0, 1.0, 2.0]}", + Tensor.from("tensor(m{},x[3]):{k1:[0,1,2], k2:[0,1,2], k3:[0,1,2], k4:[0,1,2]}").toString()); + assertEquals("tensor(m{},n{},x[3]):" + + "{{m:k1,n:k1,x:0}:0.0, {m:k1,n:k1,x:1}:1.0, {m:k1,n:k1,x:2}:2.0," + + " {m:k2,n:k1,x:0}:0.0, {m:k2,n:k1,x:1}:1.0, {m:k2,n:k1,x:2}:2.0," + + " {m:k3,n:k1,x:0}:0.0, {m:k3,n:k1,x:1}:1.0, {m:k3,n:k1,x:2}:2.0}", + Tensor.from("tensor(m{},n{},x[3]):" + + "{{m:k1,n:k1,x:0}:0, {m:k1,n:k1,x:1}:1, {m:k1,n:k1,x:2}:2, " + + " {m:k2,n:k1,x:0}:0, {m:k2,n:k1,x:1}:1, {m:k2,n:k1,x:2}:2, " + + " {m:k3,n:k1,x:0}:0, {m:k3,n:k1,x:1}:1, {m:k3,n:k1,x:2}:2}").toString()); + assertEquals("tensor(m{},x[2],y[2]):" + + "{k1:[[0.0, 1.0], [2.0, 3.0]], k2:[[0.0, 1.0], [2.0, 3.0]], k3:[[0.0, 1.0], [2.0, 3.0]]}", + Tensor.from("tensor(m{},x[2],y[2]):{k1:[[0,1],[2,3]], k2:[[0,1],[2,3]], k3:[[0,1],[2,3]]}").toString()); assertEquals("Labels are quoted when necessary", - "tensor(d1{}):{\"'''\":6.0,'[[\":\"]]':5.0}", + "tensor(d1{}):{\"'''\":6.0, '[[\":\"]]':5.0}", Tensor.from("{ {d1:'[[\":\"]]'}: 5, {d1:\"'''\"}:6.0 }").toString()); } @Test + public void testToShortString() { + assertEquals("tensor(x[10]):[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]", + Tensor.from("tensor(x[10]):[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]").toShortString()); + assertEquals("tensor(x[14]):[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, ...]", + Tensor.from("tensor(x[14]):[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]").toShortString()); + assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0, {d1:l1,d2:l3}:6.0, ...}", + Tensor.from("{{d1:l1,d2:l1}:6, {d2:l2,d1:l1}:6, {d2:l3,d1:l1}:6, {d2:l4,d1:l1}:6, {d2:l5,d1:l1}:6," + + " {d2:l6,d1:l1}:6, {d2:l7,d1:l1}:6, {d2:l8,d1:l1}:6, {d2:l9,d1:l1}:6, {d2:l2,d1:l2}:6," + + " {d2:l2,d1:l3}:6, {d2:l2,d1:l4}:6}").toShortString()); + assertEquals("tensor(m{},x[3]):{k1:[0.0, 1.0, 2.0], k2:[0.0, 1.0, ...}", + Tensor.from("tensor(m{},x[3]):{k1:[0,1,2], k2:[0,1,2], k3:[0,1,2], k4:[0,1,2]}").toShortString()); + assertEquals("tensor(m{},x[3]):{k1:[0.0, 1.0, 2.0], k2:[0.0, 1.0, ...}", + Tensor.from("tensor(m{},x[3]):{k1:[0,1,2], k2:[0,1,2], k3:[0,1,2], k4:[0,1,2]}").toShortString()); + assertEquals("tensor(m{},n{},x[3]):{{m:k1,n:k1,x:0}:0.0, {m:k1,n:k1,x:1}:1.0, {m:k1,n:k1,x:2}:2.0, ...}", + Tensor.from("tensor(m{},n{},x[3]):" + + "{{m:k1,n:k1,x:0}:0, {m:k1,n:k1,x:1}:1, {m:k1,n:k1,x:2}:2, " + + " {m:k2,n:k1,x:0}:0, {m:k2,n:k1,x:1}:1, {m:k2,n:k1,x:2}:2, " + + " {m:k3,n:k1,x:0}:0, {m:k3,n:k1,x:1}:1, {m:k3,n:k1,x:2}:2}").toShortString()); + assertEquals("tensor(m{},x[2],y[2]):{k1:[[0.0, 1.0], [2.0, 3.0]], k2:[[0.0, ...}", + Tensor.from("tensor(m{},x[2],y[2]):{k1:[[0,1],[2,3]], k2:[[0,1],[2,3]], k3:[[0,1],[2,3]]}").toShortString()); + } + + @Test public void testValueTypes() { assertEquals(Tensor.from("tensor<double>(x[1]):{{x:0}:5}").getClass(), IndexedDoubleTensor.class); assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<double>(x[1])")).cell(5.0, 0).build().getClass(), @@ -60,13 +104,6 @@ public class TensorTestCase { IndexedFloatTensor.class); } - private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) { - Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }"); - Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }"); - assertEquals(valueType, t1.multiply(t2).type().valueType()); - assertEquals(valueType, t2.multiply(t1).type().valueType()); - } - @Test public void testValueTypeResolving() { assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double"); @@ -319,6 +356,13 @@ public class TensorTestCase { "tensor(x[2],y[2]):[[4,2],[3,4]]"); } + private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) { + Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }"); + Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }"); + assertEquals(valueType, t1.multiply(t2).type().valueType()); + assertEquals(valueType, t2.multiply(t1).type().valueType()); + } + private void assertLargest(String expectedCells, String tensorString) { Tensor tensor = Tensor.from(tensorString); assertEquals(expectedCells, asString(tensor.largest(), tensor.type())); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index ce165474a53..738213ecb97 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -19,9 +19,9 @@ public class TensorFunctionTestCase { new L1Normalize<>(new ConstantTensor<>("{{x:1}:1.0}"), "x")); assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); - assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max, x), f(a,b)(a==b))", + assertTranslated("join(tensor(x{}):{1:1.0, 3:5.0, 9:3.0}, reduce(tensor(x{}):{1:1.0, 3:5.0, 9:3.0}, max, x), f(a,b)(a==b))", new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); - assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max), f(a,b)(a==b))", + assertTranslated("join(tensor(x{}):{1:1.0, 3:5.0, 9:3.0}, reduce(tensor(x{}):{1:1.0, 3:5.0, 9:3.0}, max), f(a,b)(a==b))", new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"))); } |