summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-10 11:39:39 -0800
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-10 11:39:39 -0800
commit4c46e1816d2cdfacd8435ad4d55e831929fc99ba (patch)
treed55a90aeeddcf9265a74e7f16129517e36f45375 /vespajlib/src/test/java/com/yahoo/tensor
parentb8d2859a9fece15dac2b9260d71dea39f8ce19b3 (diff)
Tensor parsing improvements
- Mixed tensor format parsing (outside expressions) - Validate structure of dense tensor strings
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java33
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java20
3 files changed, 44 insertions, 20 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
index 1928971820c..b2aba5b02eb 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
@@ -22,6 +22,12 @@ public class TensorParserTestCase {
}
@Test
+ public void testSingle() {
+ assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(),
+ "tensor(x[1]):[1.0]");
+ }
+
+ @Test
public void testDenseParsing() {
assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor()")).build(),
"tensor():{0.0}");
@@ -55,18 +61,9 @@ public class TensorParserTestCase {
.cell(3.0, 1, 0, 0)
.cell(4.0, 1, 1, 0)
.cell(5.0, 2, 0, 0)
- .cell(6.0, 2, 1, 0).build(),
- "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]");
- assertEquals("Messy input",
- Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])"))
- .cell( 1.0, 0, 0, 0)
- .cell( 2.0, 0, 1, 0)
- .cell( 3.0, 1, 0, 0)
- .cell( 4.0, 1, 1, 0)
- .cell( 5.0, 2, 0, 0)
.cell(-6.0, 2, 1, 0).build(),
- Tensor.from("tensor( x[3],y[2],z[1]) : [ [ [1.0, 2.0, 3.0] , [4.0, 5,-6.0] ] ]"));
- assertEquals("Skipping syntactic sugar",
+ "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [-6.0]]]");
+ assertEquals("Skipping structure",
Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])"))
.cell( 1.0, 0, 0, 0)
.cell( 2.0, 0, 1, 0)
@@ -77,6 +74,16 @@ public class TensorParserTestCase {
Tensor.from("tensor( x[3],y[2],z[1]) : [1.0, 2.0, 3.0 , 4.0, 5, -6.0]"));
}
+ @Test
+ public void testMixedParsing() {
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])"))
+ .cell(TensorAddress.ofLabels("a", "0"), 1)
+ .cell(TensorAddress.ofLabels("a", "1"), 2)
+ .cell(TensorAddress.ofLabels("b", "0"), 3)
+ .cell(TensorAddress.ofLabels("b", "1"), 4).build(),
+ Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}"));
+ }
+
private void assertDense(Tensor expectedTensor, String denseFormat) {
assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat));
assertEquals(denseFormat, expectedTensor.toString());
@@ -92,7 +99,7 @@ public class TensorParserTestCase {
"{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}");
assertIllegal("At {x:0}: '1-.0' is not a valid double",
"{{x:0}:1-.0}");
- assertIllegal("At index 0: '1-.0' is not a valid double",
+ assertIllegal("At position 1: '1-.0' is not a valid double",
"tensor(x[1]):[1-.0]");
}
@@ -102,7 +109,7 @@ public class TensorParserTestCase {
fail("Expected an IllegalArgumentException when parsing " + tensor);
}
catch (IllegalArgumentException e) {
- assertEquals(message, e.getMessage());
+ assertEquals(message, e.getCause().getMessage());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 11365531019..9f077cb7b00 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -56,7 +56,8 @@ public class TensorTestCase {
fail("Expected parse error");
}
catch (IllegalArgumentException expected) {
- assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'", expected.getMessage());
+ assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'",
+ expected.getCause().getMessage());
}
}
@@ -259,9 +260,9 @@ public class TensorTestCase {
assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0",
"tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l3}:5.0,{d1:l1,d2:l2}:6.0}");
assertLargest("{x:1,y:1}:4.0",
- "tensor(x[2],y[2]):[[1,2],[3,4]");
+ "tensor(x[2],y[2]):[[1,2],[3,4]]");
assertLargest("{x:0,y:0}:4.0, {x:1,y:1}:4.0",
- "tensor(x[2],y[2]):[[4,2],[3,4]");
+ "tensor(x[2],y[2]):[[4,2],[3,4]]");
}
@Test
@@ -273,9 +274,9 @@ public class TensorTestCase {
assertSmallest("{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:5.0",
"tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l3}:6.0,{d1:l1,d2:l2}:5.0}");
assertSmallest("{x:0,y:0}:1.0",
- "tensor(x[2],y[2]):[[1,2],[3,4]");
+ "tensor(x[2],y[2]):[[1,2],[3,4]]");
assertSmallest("{x:0,y:1}:2.0",
- "tensor(x[2],y[2]):[[4,2],[3,4]");
+ "tensor(x[2],y[2]):[[4,2],[3,4]]");
}
private void assertLargest(String expectedCells, String tensorString) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index e16b7b90a1d..7cddeab1641 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.Name;
import org.junit.Test;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
import static org.junit.Assert.assertEquals;
@@ -19,21 +20,36 @@ import static org.junit.Assert.assertEquals;
public class DynamicTensorTestCase {
@Test
- public void testDynamicTensorFunction() {
+ public void testDynamicIndexedRank1TensorFunction() {
TensorType dense = TensorType.fromSpec("tensor(x[3])");
DynamicTensor<Name> t1 = DynamicTensor.from(dense,
List.of(new Constant(1), new Constant(2), new Constant(3)));
assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate());
assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString());
+ }
+ @Test
+ public void testDynamicMappedRank1TensorFunction() {
TensorType sparse = TensorType.fromSpec("tensor(x{})");
DynamicTensor<Name> t2 = DynamicTensor.from(sparse,
Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
- new Constant(5)));
+ new Constant(5)));
assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate());
assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
}
+ @Test
+ public void testDynamicMappedRank2TensorFunction() {
+ TensorType sparse = TensorType.fromSpec("tensor(x{},y{})");
+ HashMap<TensorAddress, ScalarFunction<Name>> values = new HashMap<>();
+ values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "b").build(),
+ new Constant(5));
+ values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "c").build(),
+ new Constant(7));
+ DynamicTensor<Name> t2 = DynamicTensor.from(sparse, values);
+ assertEquals(Tensor.from(sparse, "{{x:a,y:b}:5, {x:a,y:c}:7}"), t2.evaluate());
+ }
+
private static class Constant implements ScalarFunction<Name> {
private final double value;