diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-07-04 16:04:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-04 16:04:33 +0200 |
commit | f57ba3d13a46a4260c4155271853dd228a6ebf3e (patch) | |
tree | c6cd3224625125a4724c41a6e4bcf4da90f2995d /config-model/src/test | |
parent | 3d820924a0a7c079df064991e5acc5240149220f (diff) | |
parent | f7102f53e8fce82589593fcc06323085a5940681 (diff) |
Merge pull request #27616 from vespa-engine/arnej/handle-more-tensor-formats
Arnej/handle more tensor formats
Diffstat (limited to 'config-model/src/test')
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java | 86 |
1 files changed, 81 insertions, 5 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java index 42be1592eca..747315c1fdf 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java @@ -19,7 +19,7 @@ public class ConstantTensorJsonValidatorTest { } private static void validateTensorJson(TensorType tensorType, Reader jsonTensorReader) { - new ConstantTensorJsonValidator().validate("dummy.json", tensorType, jsonTensorReader); + new ConstantTensorJsonValidator(tensorType).validate("dummy.json", jsonTensorReader); } @Test @@ -207,8 +207,8 @@ public class ConstantTensorJsonValidatorTest { " }", " ]", "}")); - }); - assertTrue(exception.getMessage().contains("Tensor value is not a number (VALUE_STRING)")); + }); + assertTrue(exception.getMessage().contains("Inside 'value': cell value is not a number (VALUE_STRING)")); } @Test @@ -281,8 +281,7 @@ public class ConstantTensorJsonValidatorTest { " }", "}")); }); - System.err.println("msg: " + exception.getMessage()); - assertTrue(exception.getMessage().contains("Expected 'cells' or 'values', got 'stats'")); + assertTrue(exception.getMessage().contains("Unexpected content '{' for field 'stats'")); } @Test @@ -302,4 +301,81 @@ public class ConstantTensorJsonValidatorTest { inputJsonToReader("{'cells':{'a':5,'b':4.0,'c':3.1,'d':-2,'e':-1.0}}")); } + @Test + void ensure_that_matrices_work() { + validateTensorJson( + TensorType.fromSpec("tensor(x[2], y[3])"), + inputJsonToReader( + "[", + " [ 1, 2, 3],", + " [ 4, 5, 6]", + "]")); + validateTensorJson( + TensorType.fromSpec("tensor(x[2], y[3])"), + inputJsonToReader( + "{'values':[", + " [ 1, 2, 3],", + " [ 4, 5, 6]", + "]}")); + } + + @Test + void ensure_that_simple_maps_work() { + validateTensorJson( + TensorType.fromSpec("tensor(category{})"), + inputJsonToReader( + "{", + " 'foo': 1,", + " 'bar': 2,", + " 'type': 3,", + " 'cells': 4,", + " 'value': 5,", + " 'values': 6,", + " 'blocks': 7,", + " 'anything': 8", + "}")); + validateTensorJson( + TensorType.fromSpec("tensor(category{})"), + inputJsonToReader( + "{'cells':{", + " 'foo': 1,", + " 'bar': 2,", + " 'type': 3,", + " 'cells': 4,", + " 'value': 5,", + " 'values': 6,", + " 'blocks': 7,", + " 'anything': 8", + "}}")); + } + + @Test + void ensure_that_mixing_formats_disallowed() { + Throwable exception = assertThrows(InvalidConstantTensorException.class, () -> { + validateTensorJson( + TensorType.fromSpec("tensor(x{})"), + inputJsonToReader("{ 'a': 1.0, 'cells': { 'b': 2.0 } }")); + + }); + assertTrue(exception.getMessage().contains("Cannot use {label: value} format together with 'cells'")); + } + + @Test + void ensure_that_simple_blocks_work() { + validateTensorJson( + TensorType.fromSpec("tensor(a{},b[3])"), + inputJsonToReader( + "{'blocks':{'foo':[1,2,3], 'bar':[4,5,6]}}")); + } + + @Test + void ensure_that_complex_blocks_work() { + validateTensorJson( + TensorType.fromSpec("tensor(a{},b[3],c{},d[2])"), + inputJsonToReader( + "{'blocks':[", + "{'address':{'a':'foo','c':'bar'},'values':[[1,2],[3,4],[5,6]]},", + "{'address':{'a':'qux','c':'zip'},'values':[[9,8],[7,6],[5,4]]}]}")); + } + } |