summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-07-03 12:40:14 +0000
committerArne Juul <arnej@yahooinc.com>2023-07-03 13:09:11 +0000
commitf7102f53e8fce82589593fcc06323085a5940681 (patch)
tree9f8c27bc48e35f0e8dde977cfda32a6434139e35
parent619aa88c3a8b49a1c2ea84afce9e59dc90ed75a9 (diff)
extend ConstantTensorJsonValidator.java
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java396
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java4
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java86
3 files changed, 376 insertions, 110 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
index eccb6910866..df3cd4103d9 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
@@ -11,6 +11,7 @@ import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -22,6 +23,7 @@ import java.util.stream.Collectors;
* ConstantTensorJsonValidator strictly validates a constant tensor in JSON format read from a Reader object
*
* @author Vegard Sjonfjell
+ * @author arnej
*/
public class ConstantTensorJsonValidator {
@@ -29,17 +31,63 @@ public class ConstantTensorJsonValidator {
private static final String FIELD_ADDRESS = "address";
private static final String FIELD_VALUE = "value";
private static final String FIELD_VALUES = "values";
+ private static final String FIELD_BLOCKS = "blocks";
+ private static final String FIELD_TYPE = "type";
private static final JsonFactory jsonFactory = new JsonFactory();
private JsonParser parser;
- private Map<String, TensorType.Dimension> tensorDimensions;
- private boolean isSingleDenseType = false;
- private boolean isSingleMappedType = false;
+ private final TensorType tensorType;
+ private final Map<String, TensorType.Dimension> tensorDimensions = new HashMap<>();
+ private final List<String> denseDims = new ArrayList<>();
+ private final List<String> mappedDims = new ArrayList<>();
+ private int numIndexedDims = 0;
+ private int numMappedDims = 0;
+ private boolean seenCells = false;
+ private boolean seenValues = false;
+ private boolean seenBlocks = false;
+ private boolean seenType = false;
+ private boolean seenSimpleMapValue = false;
+
+ private boolean isScalar() {
+ return (numIndexedDims == 0 && numMappedDims == 0);
+ }
+ private boolean isDense() {
+ return (numIndexedDims > 0 && numMappedDims == 0);
+ }
+ private boolean isSparse() {
+ return (numIndexedDims == 0 && numMappedDims > 0);
+ }
+ private boolean isSingleDense() {
+ return (numIndexedDims == 1 && numMappedDims == 0);
+ }
+ private boolean isSingleSparse() {
+ return (numIndexedDims == 0 && numMappedDims == 1);
+ }
+ private boolean isMixed() {
+ return (numIndexedDims > 0 && numMappedDims > 0);
+ }
- public void validate(String fileName, TensorType type, Reader tensorData) {
+ public ConstantTensorJsonValidator(TensorType type) {
+ this.tensorType = type;
+ for (var dim : type.dimensions()) {
+ tensorDimensions.put(dim.name(), dim);
+ switch (dim.type()) {
+ case mapped:
+ ++numMappedDims;
+ mappedDims.add(dim.name());
+ break;
+ case indexedBound:
+ case indexedUnbound:
+ ++numIndexedDims;
+ denseDims.add(dim.name());
+ }
+ }
+ }
+
+ public void validate(String fileName, Reader tensorData) {
if (fileName.endsWith(".json")) {
- validateTensor(type, tensorData);
+ validateTensor(tensorData);
}
else if (fileName.endsWith(".json.lz4")) {
// don't validate; the cost probably outweights the advantage
@@ -53,127 +101,156 @@ public class ConstantTensorJsonValidator {
}
}
- private void validateTensor(TensorType type, Reader tensorData) {
- wrapIOException(() -> {
+ private void validateTensor(Reader tensorData) {
+ try {
this.parser = jsonFactory.createParser(tensorData);
- this.tensorDimensions = type
- .dimensions()
- .stream()
- .collect(Collectors.toMap(TensorType.Dimension::name, Function.identity()));
- if (type.dimensions().size() == 1) {
- this.isSingleMappedType = (type.indexedSubtype() == TensorType.empty);
- this.isSingleDenseType = (type.mappedSubtype() == TensorType.empty);
- }
var top = parser.nextToken();
- if (top == JsonToken.START_ARRAY) {
+ if (top == JsonToken.START_ARRAY && isDense()) {
consumeValuesArray();
+ return;
} else if (top == JsonToken.START_OBJECT) {
consumeTopObject();
+ return;
+ } else if (isScalar()) {
+ if (top == JsonToken.VALUE_NUMBER_FLOAT || top == JsonToken.VALUE_NUMBER_INT) {
+ return;
+ }
+ }
+ throw new InvalidConstantTensorException(
+ parser, String.format("Unexpected first token '%s' for constant with type %s",
+ parser.getText(), tensorType.toString()));
+ } catch (IOException e) {
+ if (parser != null) {
+ throw new InvalidConstantTensorException(parser, e);
}
- });
+ throw new InvalidConstantTensorException(e);
+ }
}
private void consumeValuesArray() throws IOException {
- if (! isSingleDenseType) {
- throw new InvalidConstantTensorException(parser, String.format("Field 'values' is only valid for simple vectors (1-d dense tensors"));
- }
- assertCurrentTokenIs(JsonToken.START_ARRAY);
- while (parser.nextToken() != JsonToken.END_ARRAY) {
- validateNumeric(parser.getCurrentToken());
- }
+ consumeValuesNesting(0);
}
private void consumeTopObject() throws IOException {
- assertCurrentTokenIs(JsonToken.START_OBJECT);
- assertNextTokenIs(JsonToken.FIELD_NAME);
- String fieldName = parser.getCurrentName();
- if (fieldName.equals(FIELD_VALUES)) {
- assertNextTokenIs(JsonToken.START_ARRAY);
- consumeValuesArray();
- } else if (fieldName.equals(FIELD_CELLS)) {
- consumeCellsField();
- } else {
- throw new InvalidConstantTensorException(parser, String.format("Expected 'cells' or 'values', got '%s'", fieldName));
+ for (var cur = parser.nextToken(); cur != JsonToken.END_OBJECT; cur = parser.nextToken()) {
+ assertCurrentTokenIs(JsonToken.FIELD_NAME);
+ String fieldName = parser.getCurrentName();
+ switch (fieldName) {
+ case FIELD_TYPE:
+ consumeTypeField();
+ break;
+ case FIELD_VALUES:
+ consumeValuesField();
+ break;
+ case FIELD_CELLS:
+ consumeCellsField();
+ break;
+ case FIELD_BLOCKS:
+ consumeBlocksField();
+ break;
+ default:
+ consumeAnyField(fieldName, parser.nextToken());
+ break;
+ }
+ }
+ if (seenSimpleMapValue) {
+ if (! isSingleSparse()) {
+ throw new InvalidConstantTensorException(parser, String.format("Cannot use {label: value} format for constant of type %s", tensorType.toString()));
+ }
+ if (seenCells || seenValues || seenBlocks || seenType) {
+ throw new InvalidConstantTensorException(parser, String.format("Cannot use {label: value} format together with '%s'",
+ (seenCells ? FIELD_CELLS :
+ (seenValues ? FIELD_VALUES :
+ (seenBlocks ? FIELD_BLOCKS : FIELD_TYPE)))));
+ }
+ }
+ if (seenCells) {
+ if (seenValues || seenBlocks) {
+ throw new InvalidConstantTensorException(parser, String.format("Cannot use both '%s' and '%s' at the same time",
+ FIELD_CELLS, (seenValues ? FIELD_VALUES : FIELD_BLOCKS)));
+ }
+ }
+ if (seenValues && seenBlocks) {
+ throw new InvalidConstantTensorException(parser, String.format("Cannot use both '%s' and '%s' at the same time",
+ FIELD_VALUES, FIELD_BLOCKS));
}
- assertNextTokenIs(JsonToken.END_OBJECT);
}
private void consumeCellsField() throws IOException {
- var token = parser.nextToken();
- if (token == JsonToken.START_ARRAY) {
+ var cur = parser.nextToken();
+ if (cur == JsonToken.START_ARRAY) {
consumeLiteralFormArray();
- } else if (token == JsonToken.START_OBJECT) {
+ seenCells = true;
+ } else if (cur == JsonToken.START_OBJECT) {
consumeSimpleMappedObject();
+ seenCells = true;
} else {
- throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be object or array, but got %s", token.toString()));
+ consumeAnyField(FIELD_BLOCKS, cur);
}
}
private void consumeLiteralFormArray() throws IOException {
while (parser.nextToken() != JsonToken.END_ARRAY) {
- validateTensorCell();
+ validateLiteralFormCell();
}
}
private void consumeSimpleMappedObject() throws IOException {
- if (! isSingleMappedType) {
- throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be an array of address/value objects"));
+ if (! isSingleSparse()) {
+ throw new InvalidConstantTensorException(parser, String.format("Cannot use {label: value} format for constant of type %s", tensorType.toString()));
}
- while (parser.nextToken() != JsonToken.END_OBJECT) {
+ for (var cur = parser.nextToken(); cur != JsonToken.END_OBJECT; cur = parser.nextToken()) {
assertCurrentTokenIs(JsonToken.FIELD_NAME);
- validateTensorCellValue();
+ validateNumeric(parser.getCurrentName(), parser.nextToken());
}
}
- private void validateTensorCell() {
- wrapIOException(() -> {
- assertCurrentTokenIs(JsonToken.START_OBJECT);
-
- List<String> fieldNameCandidates = new ArrayList<>(Arrays.asList(FIELD_ADDRESS, FIELD_VALUE));
- for (int i = 0; i < 2; i++) {
- assertNextTokenIs(JsonToken.FIELD_NAME);
- String fieldName = parser.getCurrentName();
-
- if (fieldNameCandidates.contains(fieldName)) {
- fieldNameCandidates.remove(fieldName);
-
- if (fieldName.equals(FIELD_ADDRESS)) {
- validateTensorAddress();
- } else if (fieldName.equals(FIELD_VALUE)) {
- validateTensorCellValue();
- }
- } else {
- throw new InvalidConstantTensorException(parser, "Only 'address' or 'value' fields are permitted within a cell object");
- }
+ private void validateLiteralFormCell() throws IOException {
+ assertCurrentTokenIs(JsonToken.START_OBJECT);
+ boolean seenAddress = false;
+ boolean seenValue = false;
+ for (int i = 0; i < 2; i++) {
+ assertNextTokenIs(JsonToken.FIELD_NAME);
+ String fieldName = parser.getCurrentName();
+ switch (fieldName) {
+ case FIELD_ADDRESS:
+ validateTensorAddress(new HashSet<>(tensorDimensions.keySet()));
+ seenAddress = true;
+ break;
+ case FIELD_VALUE:
+ validateNumeric(FIELD_VALUE, parser.nextToken());
+ seenValue = true;
+ break;
+ default:
+ throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a cell object",
+ FIELD_ADDRESS, FIELD_VALUE));
}
-
- assertNextTokenIs(JsonToken.END_OBJECT);
- });
+ }
+ if (! seenAddress) {
+ throw new InvalidConstantTensorException(parser, String.format("Missing '%s' field in cell object", FIELD_ADDRESS));
+ }
+ if (! seenValue) {
+ throw new InvalidConstantTensorException(parser, String.format("Missing '%s' field in cell object", FIELD_VALUE));
+ }
+ assertNextTokenIs(JsonToken.END_OBJECT);
}
- private void validateTensorAddress() throws IOException {
+ private void validateTensorAddress(Set<String> cellDimensions) throws IOException {
assertNextTokenIs(JsonToken.START_OBJECT);
-
- Set<String> cellDimensions = new HashSet<>(tensorDimensions.keySet());
-
// Iterate within the address key, value pairs
while ((parser.nextToken() != JsonToken.END_OBJECT)) {
assertCurrentTokenIs(JsonToken.FIELD_NAME);
-
String dimensionName = parser.getCurrentName();
TensorType.Dimension dimension = tensorDimensions.get(dimensionName);
if (dimension == null) {
throw new InvalidConstantTensorException(parser, String.format("Tensor dimension '%s' does not exist", parser.getCurrentName()));
}
-
if (!cellDimensions.contains(dimensionName)) {
throw new InvalidConstantTensorException(parser, String.format("Duplicate tensor dimension '%s'", parser.getCurrentName()));
}
-
cellDimensions.remove(dimensionName);
validateLabel(dimension);
}
-
if (!cellDimensions.isEmpty()) {
throw new InvalidConstantTensorException(parser, String.format("Tensor address missing dimension(s) %s", Joiner.on(", ").join(cellDimensions)));
}
@@ -186,9 +263,9 @@ public class ConstantTensorJsonValidator {
*/
private void validateLabel(TensorType.Dimension dimension) throws IOException {
JsonToken token = parser.nextToken();
- if (token != JsonToken.VALUE_STRING)
+ if (token != JsonToken.VALUE_STRING) {
throw new InvalidConstantTensorException(parser, String.format("Tensor label is not a string (%s)", token.toString()));
-
+ }
if (dimension instanceof TensorType.IndexedBoundDimension) {
validateBoundIndex((TensorType.IndexedBoundDimension) dimension);
} else if (dimension instanceof TensorType.IndexedUnboundDimension) {
@@ -196,40 +273,31 @@ public class ConstantTensorJsonValidator {
}
}
- private void validateBoundIndex(TensorType.IndexedBoundDimension dimension) {
- wrapIOException(() -> {
- try {
- int value = Integer.parseInt(parser.getValueAsString());
- if (value >= dimension.size().get())
- throw new InvalidConstantTensorException(parser, String.format("Index %s not within limits of bound dimension '%s'", value, dimension.name()));
- } catch (NumberFormatException e) {
- throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name());
- }
- });
+ private void validateBoundIndex(TensorType.IndexedBoundDimension dimension) throws IOException {
+ try {
+ int value = Integer.parseInt(parser.getValueAsString());
+ if (value >= dimension.size().get())
+ throw new InvalidConstantTensorException(parser, String.format("Index %s not within limits of bound dimension '%s'", value, dimension.name()));
+ } catch (NumberFormatException e) {
+ throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name());
+ }
}
- private void validateUnboundIndex(TensorType.Dimension dimension) {
- wrapIOException(() -> {
- try {
- Integer.parseInt(parser.getValueAsString());
- } catch (NumberFormatException e) {
- throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name());
- }
- });
+ private void validateUnboundIndex(TensorType.Dimension dimension) throws IOException {
+ try {
+ Integer.parseInt(parser.getValueAsString());
+ } catch (NumberFormatException e) {
+ throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name());
+ }
}
private void throwCoordinateIsNotInteger(String value, String dimensionName) {
throw new InvalidConstantTensorException(parser, String.format("Index '%s' for dimension '%s' is not an integer", value, dimensionName));
}
- private void validateTensorCellValue() throws IOException {
- JsonToken token = parser.nextToken();
- validateNumeric(token);
- }
-
- private void validateNumeric(JsonToken token) throws IOException {
+ private void validateNumeric(String where, JsonToken token) throws IOException {
if (token != JsonToken.VALUE_NUMBER_FLOAT && token != JsonToken.VALUE_NUMBER_INT) {
- throw new InvalidConstantTensorException(parser, String.format("Tensor value is not a number (%s)", token.toString()));
+ throw new InvalidConstantTensorException(parser, String.format("Inside '%s': cell value is not a number (%s)", where, token.toString()));
}
}
@@ -265,6 +333,9 @@ public class ConstantTensorJsonValidator {
super("Failed to parse JSON stream " + parser.getCurrentLocation().toString(), base);
}
+ InvalidConstantTensorException(IOException base) {
+ super("Failed to parse JSON stream: " + base.getMessage(), base);
+ }
}
@FunctionalInterface
@@ -280,4 +351,125 @@ public class ConstantTensorJsonValidator {
}
}
+ private void consumeValuesNesting(int level) throws IOException {
+ assertCurrentTokenIs(JsonToken.START_ARRAY);
+ if (level >= denseDims.size()) {
+ throw new InvalidConstantTensorException(
+ parser, String.format("Too deep array nesting for constant with type %s", tensorType.toString()));
+ }
+ var dim = tensorDimensions.get(denseDims.get(level));
+ long count = 0;
+ for (var cur = parser.nextToken(); cur != JsonToken.END_ARRAY; cur = parser.nextToken()) {
+ if (level + 1 == denseDims.size()) {
+ validateNumeric(FIELD_VALUES, cur);
+ } else if (cur == JsonToken.START_ARRAY) {
+ consumeValuesNesting(level + 1);
+ } else {
+ throw new InvalidConstantTensorException(
+ parser, String.format("Unexpected token %s '%s'", cur.toString(), parser.getText()));
+ }
+ ++count;
+ }
+ if (dim.size().isPresent()) {
+ var sz = dim.size().get();
+ if (sz != count) {
+ throw new InvalidConstantTensorException(
+ parser, String.format("Dimension '%s' has size %d but array had %d values", dim.name(), sz, count));
+ }
+ }
+ }
+
+ private void consumeTypeField() throws IOException {
+ var cur = parser.nextToken();
+ if (cur == JsonToken.VALUE_STRING) {
+ seenType = true;
+ } else if (isSingleSparse()) {
+ validateNumeric(FIELD_TYPE, cur);
+ seenSimpleMapValue = true;
+ } else {
+ throw new InvalidConstantTensorException(
+ parser, String.format("Field '%s' should contain the tensor type as a string, got %s", FIELD_TYPE, parser.getText()));
+ }
+ }
+
+ private void consumeValuesField() throws IOException {
+ var cur = parser.nextToken();
+ if (isDense() && cur == JsonToken.START_ARRAY) {
+ consumeValuesArray();
+ seenValues = true;
+ } else {
+ consumeAnyField(FIELD_VALUES, cur);
+ }
+ }
+
+ private void consumeBlocksField() throws IOException {
+ var cur = parser.nextToken();
+ if (cur == JsonToken.START_ARRAY) {
+ consumeBlocksArray();
+ seenBlocks = true;
+ } else if (cur == JsonToken.START_OBJECT) {
+ consumeBlocksObject();
+ seenBlocks = true;
+ } else {
+ consumeAnyField(FIELD_BLOCKS, cur);
+ }
+ }
+
+ private void consumeAnyField(String fieldName, JsonToken cur) throws IOException {
+ if (isSingleSparse()) {
+ validateNumeric(FIELD_CELLS, cur);
+ seenSimpleMapValue = true;
+ } else {
+ throw new InvalidConstantTensorException(
+ parser, String.format("Unexpected content '%s' for field '%s'", parser.getText(), fieldName));
+ }
+ }
+
+ private void consumeBlocksArray() throws IOException {
+ if (! isMixed()) {
+ throw new InvalidConstantTensorException(parser, String.format("Cannot use blocks format:[] for constant of type %s", tensorType.toString()));
+ }
+ while (parser.nextToken() != JsonToken.END_ARRAY) {
+ assertCurrentTokenIs(JsonToken.START_OBJECT);
+ boolean seenAddress = false;
+ boolean seenValues = false;
+ for (int i = 0; i < 2; i++) {
+ assertNextTokenIs(JsonToken.FIELD_NAME);
+ String fieldName = parser.getCurrentName();
+ switch (fieldName) {
+ case FIELD_ADDRESS:
+ validateTensorAddress(new HashSet<>(mappedDims));
+ seenAddress = true;
+ break;
+ case FIELD_VALUES:
+ assertNextTokenIs(JsonToken.START_ARRAY);
+ consumeValuesArray();
+ seenValues = true;
+ break;
+ default:
+ throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a block object",
+ FIELD_ADDRESS, FIELD_VALUES));
+ }
+ }
+ if (! seenAddress) {
+ throw new InvalidConstantTensorException(parser, String.format("Missing '%s' field in block object", FIELD_ADDRESS));
+ }
+ if (! seenValues) {
+ throw new InvalidConstantTensorException(parser, String.format("Missing '%s' field in block object", FIELD_VALUES));
+ }
+ assertNextTokenIs(JsonToken.END_OBJECT);
+ }
+ }
+
+ private void consumeBlocksObject() throws IOException {
+ if (numMappedDims > 1 || ! isMixed()) {
+ throw new InvalidConstantTensorException(parser, String.format("Cannot use blocks:{} format for constant of type %s", tensorType.toString()));
+ }
+ while (parser.nextToken() != JsonToken.END_OBJECT) {
+ assertCurrentTokenIs(JsonToken.FIELD_NAME);
+ assertNextTokenIs(JsonToken.START_ARRAY);
+ consumeValuesArray();
+ }
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java
index bc91c7cbad2..b5fc41eaac9 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java
@@ -59,9 +59,7 @@ public class ConstantValidator extends Validator {
}
ApplicationFile tensorApplicationFile = application.getFile(Path.fromString(constantFile));
- new ConstantTensorJsonValidator().validate(constantFile,
- rankingConstant.type(),
- tensorApplicationFile.createReader());
+ new ConstantTensorJsonValidator(rankingConstant.type()).validate(constantFile, tensorApplicationFile.createReader());
}
private static class ExceptionMessageCollector {
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
index 42be1592eca..747315c1fdf 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
@@ -19,7 +19,7 @@ public class ConstantTensorJsonValidatorTest {
}
private static void validateTensorJson(TensorType tensorType, Reader jsonTensorReader) {
- new ConstantTensorJsonValidator().validate("dummy.json", tensorType, jsonTensorReader);
+ new ConstantTensorJsonValidator(tensorType).validate("dummy.json", jsonTensorReader);
}
@Test
@@ -207,8 +207,8 @@ public class ConstantTensorJsonValidatorTest {
" }",
" ]",
"}"));
- });
- assertTrue(exception.getMessage().contains("Tensor value is not a number (VALUE_STRING)"));
+ });
+ assertTrue(exception.getMessage().contains("Inside 'value': cell value is not a number (VALUE_STRING)"));
}
@Test
@@ -281,8 +281,7 @@ public class ConstantTensorJsonValidatorTest {
" }",
"}"));
});
- System.err.println("msg: " + exception.getMessage());
- assertTrue(exception.getMessage().contains("Expected 'cells' or 'values', got 'stats'"));
+ assertTrue(exception.getMessage().contains("Unexpected content '{' for field 'stats'"));
}
@Test
@@ -302,4 +301,81 @@ public class ConstantTensorJsonValidatorTest {
inputJsonToReader("{'cells':{'a':5,'b':4.0,'c':3.1,'d':-2,'e':-1.0}}"));
}
+ @Test
+ void ensure_that_matrices_work() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(x[2], y[3])"),
+ inputJsonToReader(
+ "[",
+ " [ 1, 2, 3],",
+ " [ 4, 5, 6]",
+ "]"));
+ validateTensorJson(
+ TensorType.fromSpec("tensor(x[2], y[3])"),
+ inputJsonToReader(
+ "{'values':[",
+ " [ 1, 2, 3],",
+ " [ 4, 5, 6]",
+ "]}"));
+ }
+
+ @Test
+ void ensure_that_simple_maps_work() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(category{})"),
+ inputJsonToReader(
+ "{",
+ " 'foo': 1,",
+ " 'bar': 2,",
+ " 'type': 3,",
+ " 'cells': 4,",
+ " 'value': 5,",
+ " 'values': 6,",
+ " 'blocks': 7,",
+ " 'anything': 8",
+ "}"));
+ validateTensorJson(
+ TensorType.fromSpec("tensor(category{})"),
+ inputJsonToReader(
+ "{'cells':{",
+ " 'foo': 1,",
+ " 'bar': 2,",
+ " 'type': 3,",
+ " 'cells': 4,",
+ " 'value': 5,",
+ " 'values': 6,",
+ " 'blocks': 7,",
+ " 'anything': 8",
+ "}}"));
+ }
+
+ @Test
+ void ensure_that_mixing_formats_disallowed() {
+ Throwable exception = assertThrows(InvalidConstantTensorException.class, () -> {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(x{})"),
+ inputJsonToReader("{ 'a': 1.0, 'cells': { 'b': 2.0 } }"));
+
+ });
+ assertTrue(exception.getMessage().contains("Cannot use {label: value} format together with 'cells'"));
+ }
+
+ @Test
+ void ensure_that_simple_blocks_work() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(a{},b[3])"),
+ inputJsonToReader(
+ "{'blocks':{'foo':[1,2,3], 'bar':[4,5,6]}}"));
+ }
+
+ @Test
+ void ensure_that_complex_blocks_work() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(a{},b[3],c{},d[2])"),
+ inputJsonToReader(
+ "{'blocks':[",
+ "{'address':{'a':'foo','c':'bar'},'values':[[1,2],[3,4],[5,6]]},",
+ "{'address':{'a':'qux','c':'zip'},'values':[[9,8],[7,6],[5,4]]}]}"));
+ }
+
}