diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-10 11:39:39 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-10 11:39:39 -0800 |
commit | 4c46e1816d2cdfacd8435ad4d55e831929fc99ba (patch) | |
tree | d55a90aeeddcf9265a74e7f16129517e36f45375 /vespajlib/src/test/java/com | |
parent | b8d2859a9fece15dac2b9260d71dea39f8ce19b3 (diff) |
Tensor parsing improvements
- Mixed tensor format parsing (outside expressions)
- Validate structure of dense tensor strings
Diffstat (limited to 'vespajlib/src/test/java/com')
3 files changed, 44 insertions, 20 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index 1928971820c..b2aba5b02eb 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -22,6 +22,12 @@ public class TensorParserTestCase { } @Test + public void testSingle() { + assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(), + "tensor(x[1]):[1.0]"); + } + + @Test public void testDenseParsing() { assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor()")).build(), "tensor():{0.0}"); @@ -55,18 +61,9 @@ public class TensorParserTestCase { .cell(3.0, 1, 0, 0) .cell(4.0, 1, 1, 0) .cell(5.0, 2, 0, 0) - .cell(6.0, 2, 1, 0).build(), - "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]"); - assertEquals("Messy input", - Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])")) - .cell( 1.0, 0, 0, 0) - .cell( 2.0, 0, 1, 0) - .cell( 3.0, 1, 0, 0) - .cell( 4.0, 1, 1, 0) - .cell( 5.0, 2, 0, 0) .cell(-6.0, 2, 1, 0).build(), - Tensor.from("tensor( x[3],y[2],z[1]) : [ [ [1.0, 2.0, 3.0] , [4.0, 5,-6.0] ] ]")); - assertEquals("Skipping syntactic sugar", + "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [-6.0]]]"); + assertEquals("Skipping structure", Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])")) .cell( 1.0, 0, 0, 0) .cell( 2.0, 0, 1, 0) @@ -77,6 +74,16 @@ public class TensorParserTestCase { Tensor.from("tensor( x[3],y[2],z[1]) : [1.0, 2.0, 3.0 , 4.0, 5, -6.0]")); } + @Test + public void testMixedParsing() { + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])")) + .cell(TensorAddress.ofLabels("a", "0"), 1) + .cell(TensorAddress.ofLabels("a", "1"), 2) + .cell(TensorAddress.ofLabels("b", "0"), 3) + .cell(TensorAddress.ofLabels("b", "1"), 4).build(), + Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}")); + } + private void assertDense(Tensor expectedTensor, String denseFormat) { assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat)); assertEquals(denseFormat, expectedTensor.toString()); @@ -92,7 +99,7 @@ public class TensorParserTestCase { "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}"); assertIllegal("At {x:0}: '1-.0' is not a valid double", "{{x:0}:1-.0}"); - assertIllegal("At index 0: '1-.0' is not a valid double", + assertIllegal("At position 1: '1-.0' is not a valid double", "tensor(x[1]):[1-.0]"); } @@ -102,7 +109,7 @@ public class TensorParserTestCase { fail("Expected an IllegalArgumentException when parsing " + tensor); } catch (IllegalArgumentException e) { - assertEquals(message, e.getMessage()); + assertEquals(message, e.getCause().getMessage()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 11365531019..9f077cb7b00 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -56,7 +56,8 @@ public class TensorTestCase { fail("Expected parse error"); } catch (IllegalArgumentException expected) { - assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'", expected.getMessage()); + assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'", + expected.getCause().getMessage()); } } @@ -259,9 +260,9 @@ public class TensorTestCase { assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0", "tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l3}:5.0,{d1:l1,d2:l2}:6.0}"); assertLargest("{x:1,y:1}:4.0", - "tensor(x[2],y[2]):[[1,2],[3,4]"); + "tensor(x[2],y[2]):[[1,2],[3,4]]"); assertLargest("{x:0,y:0}:4.0, {x:1,y:1}:4.0", - "tensor(x[2],y[2]):[[4,2],[3,4]"); + "tensor(x[2],y[2]):[[4,2],[3,4]]"); } @Test @@ -273,9 +274,9 @@ public class TensorTestCase { assertSmallest("{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:5.0", "tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l3}:6.0,{d1:l1,d2:l2}:5.0}"); assertSmallest("{x:0,y:0}:1.0", - "tensor(x[2],y[2]):[[1,2],[3,4]"); + "tensor(x[2],y[2]):[[1,2],[3,4]]"); assertSmallest("{x:0,y:1}:2.0", - "tensor(x[2],y[2]):[[4,2],[3,4]"); + "tensor(x[2],y[2]):[[4,2],[3,4]]"); } private void assertLargest(String expectedCells, String tensorString) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index e16b7b90a1d..7cddeab1641 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.Name; import org.junit.Test; import java.util.Collections; +import java.util.HashMap; import java.util.List; import static org.junit.Assert.assertEquals; @@ -19,21 +20,36 @@ import static org.junit.Assert.assertEquals; public class DynamicTensorTestCase { @Test - public void testDynamicTensorFunction() { + public void testDynamicIndexedRank1TensorFunction() { TensorType dense = TensorType.fromSpec("tensor(x[3])"); DynamicTensor<Name> t1 = DynamicTensor.from(dense, List.of(new Constant(1), new Constant(2), new Constant(3))); assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate()); assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString()); + } + @Test + public void testDynamicMappedRank1TensorFunction() { TensorType sparse = TensorType.fromSpec("tensor(x{})"); DynamicTensor<Name> t2 = DynamicTensor.from(sparse, Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), - new Constant(5))); + new Constant(5))); assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); } + @Test + public void testDynamicMappedRank2TensorFunction() { + TensorType sparse = TensorType.fromSpec("tensor(x{},y{})"); + HashMap<TensorAddress, ScalarFunction<Name>> values = new HashMap<>(); + values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "b").build(), + new Constant(5)); + values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "c").build(), + new Constant(7)); + DynamicTensor<Name> t2 = DynamicTensor.from(sparse, values); + assertEquals(Tensor.from(sparse, "{{x:a,y:b}:5, {x:a,y:c}:7}"), t2.evaluate()); + } + private static class Constant implements ScalarFunction<Name> { private final double value; |