aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-07-04 16:04:33 +0200
committerGitHub <noreply@github.com>2023-07-04 16:04:33 +0200
commitf57ba3d13a46a4260c4155271853dd228a6ebf3e (patch)
treec6cd3224625125a4724c41a6e4bcf4da90f2995d /config-model/src/test/java
parent3d820924a0a7c079df064991e5acc5240149220f (diff)
parentf7102f53e8fce82589593fcc06323085a5940681 (diff)
Merge pull request #27616 from vespa-engine/arnej/handle-more-tensor-formats
Arnej/handle more tensor formats
Diffstat (limited to 'config-model/src/test/java')
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java86
1 files changed, 81 insertions, 5 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
index 42be1592eca..747315c1fdf 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
@@ -19,7 +19,7 @@ public class ConstantTensorJsonValidatorTest {
}
private static void validateTensorJson(TensorType tensorType, Reader jsonTensorReader) {
- new ConstantTensorJsonValidator().validate("dummy.json", tensorType, jsonTensorReader);
+ new ConstantTensorJsonValidator(tensorType).validate("dummy.json", jsonTensorReader);
}
@Test
@@ -207,8 +207,8 @@ public class ConstantTensorJsonValidatorTest {
" }",
" ]",
"}"));
- });
- assertTrue(exception.getMessage().contains("Tensor value is not a number (VALUE_STRING)"));
+ });
+ assertTrue(exception.getMessage().contains("Inside 'value': cell value is not a number (VALUE_STRING)"));
}
@Test
@@ -281,8 +281,7 @@ public class ConstantTensorJsonValidatorTest {
" }",
"}"));
});
- System.err.println("msg: " + exception.getMessage());
- assertTrue(exception.getMessage().contains("Expected 'cells' or 'values', got 'stats'"));
+ assertTrue(exception.getMessage().contains("Unexpected content '{' for field 'stats'"));
}
@Test
@@ -302,4 +301,81 @@ public class ConstantTensorJsonValidatorTest {
inputJsonToReader("{'cells':{'a':5,'b':4.0,'c':3.1,'d':-2,'e':-1.0}}"));
}
+ @Test
+ void ensure_that_matrices_work() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(x[2], y[3])"),
+ inputJsonToReader(
+ "[",
+ " [ 1, 2, 3],",
+ " [ 4, 5, 6]",
+ "]"));
+ validateTensorJson(
+ TensorType.fromSpec("tensor(x[2], y[3])"),
+ inputJsonToReader(
+ "{'values':[",
+ " [ 1, 2, 3],",
+ " [ 4, 5, 6]",
+ "]}"));
+ }
+
+ @Test
+ void ensure_that_simple_maps_work() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(category{})"),
+ inputJsonToReader(
+ "{",
+ " 'foo': 1,",
+ " 'bar': 2,",
+ " 'type': 3,",
+ " 'cells': 4,",
+ " 'value': 5,",
+ " 'values': 6,",
+ " 'blocks': 7,",
+ " 'anything': 8",
+ "}"));
+ validateTensorJson(
+ TensorType.fromSpec("tensor(category{})"),
+ inputJsonToReader(
+ "{'cells':{",
+ " 'foo': 1,",
+ " 'bar': 2,",
+ " 'type': 3,",
+ " 'cells': 4,",
+ " 'value': 5,",
+ " 'values': 6,",
+ " 'blocks': 7,",
+ " 'anything': 8",
+ "}}"));
+ }
+
+ @Test
+ void ensure_that_mixing_formats_disallowed() {
+ Throwable exception = assertThrows(InvalidConstantTensorException.class, () -> {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(x{})"),
+ inputJsonToReader("{ 'a': 1.0, 'cells': { 'b': 2.0 } }"));
+
+ });
+ assertTrue(exception.getMessage().contains("Cannot use {label: value} format together with 'cells'"));
+ }
+
+ @Test
+ void ensure_that_simple_blocks_work() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(a{},b[3])"),
+ inputJsonToReader(
+ "{'blocks':{'foo':[1,2,3], 'bar':[4,5,6]}}"));
+ }
+
+ @Test
+ void ensure_that_complex_blocks_work() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(a{},b[3],c{},d[2])"),
+ inputJsonToReader(
+ "{'blocks':[",
+ "{'address':{'a':'foo','c':'bar'},'values':[[1,2],[3,4],[5,6]]},",
+ "{'address':{'a':'qux','c':'zip'},'values':[[9,8],[7,6],[5,4]]}]}"));
+ }
+
}