diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-05-22 11:42:17 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-22 11:42:17 +0200 |
commit | 07d0699378a589bcdffa08bba840b695a185ef99 (patch) | |
tree | ae4a6057d52f77bd2a14fab214774ffa90b86a03 /config-model/src/main | |
parent | c7a07adf43c13165e49e2aa2ef509ecb2526a48c (diff) | |
parent | 95d2e5194acd87facd594201dd1db254a41b1f73 (diff) |
Merge pull request #27137 from vespa-engine/arnej/allow-some-short-forms-for-constant-tensors
allow short-form JSON for 1-d constants
Diffstat (limited to 'config-model/src/main')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java | 76 |
1 files changed, 66 insertions, 10 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java index 66da43856b1..eccb6910866 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java @@ -28,11 +28,14 @@ public class ConstantTensorJsonValidator { private static final String FIELD_CELLS = "cells"; private static final String FIELD_ADDRESS = "address"; private static final String FIELD_VALUE = "value"; + private static final String FIELD_VALUES = "values"; private static final JsonFactory jsonFactory = new JsonFactory(); private JsonParser parser; private Map<String, TensorType.Dimension> tensorDimensions; + private boolean isSingleDenseType = false; + private boolean isSingleMappedType = false; public void validate(String fileName, TensorType type, Reader tensorData) { if (fileName.endsWith(".json")) { @@ -57,19 +60,69 @@ public class ConstantTensorJsonValidator { .dimensions() .stream() .collect(Collectors.toMap(TensorType.Dimension::name, Function.identity())); + if (type.dimensions().size() == 1) { + this.isSingleMappedType = (type.indexedSubtype() == TensorType.empty); + this.isSingleDenseType = (type.mappedSubtype() == TensorType.empty); + } + var top = parser.nextToken(); + if (top == JsonToken.START_ARRAY) { + consumeValuesArray(); + } else if (top == JsonToken.START_OBJECT) { + consumeTopObject(); + } + }); + } - assertNextTokenIs(JsonToken.START_OBJECT); - assertNextTokenIs(JsonToken.FIELD_NAME); - assertFieldNameIs(FIELD_CELLS); + private void consumeValuesArray() throws IOException { + if (! isSingleDenseType) { + throw new InvalidConstantTensorException(parser, String.format("Field 'values' is only valid for simple vectors (1-d dense tensors")); + } + assertCurrentTokenIs(JsonToken.START_ARRAY); + while (parser.nextToken() != JsonToken.END_ARRAY) { + validateNumeric(parser.getCurrentToken()); + } + } + private void consumeTopObject() throws IOException { + assertCurrentTokenIs(JsonToken.START_OBJECT); + assertNextTokenIs(JsonToken.FIELD_NAME); + String fieldName = parser.getCurrentName(); + if (fieldName.equals(FIELD_VALUES)) { assertNextTokenIs(JsonToken.START_ARRAY); + consumeValuesArray(); + } else if (fieldName.equals(FIELD_CELLS)) { + consumeCellsField(); + } else { + throw new InvalidConstantTensorException(parser, String.format("Expected 'cells' or 'values', got '%s'", fieldName)); + } + assertNextTokenIs(JsonToken.END_OBJECT); + } - while (parser.nextToken() != JsonToken.END_ARRAY) { - validateTensorCell(); - } + private void consumeCellsField() throws IOException { + var token = parser.nextToken(); + if (token == JsonToken.START_ARRAY) { + consumeLiteralFormArray(); + } else if (token == JsonToken.START_OBJECT) { + consumeSimpleMappedObject(); + } else { + throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be object or array, but got %s", token.toString())); + } + } - assertNextTokenIs(JsonToken.END_OBJECT); - }); + private void consumeLiteralFormArray() throws IOException { + while (parser.nextToken() != JsonToken.END_ARRAY) { + validateTensorCell(); + } + } + + private void consumeSimpleMappedObject() throws IOException { + if (! isSingleMappedType) { + throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be an array of address/value objects")); + } + while (parser.nextToken() != JsonToken.END_OBJECT) { + assertCurrentTokenIs(JsonToken.FIELD_NAME); + validateTensorCellValue(); + } } private void validateTensorCell() { @@ -87,7 +140,7 @@ public class ConstantTensorJsonValidator { if (fieldName.equals(FIELD_ADDRESS)) { validateTensorAddress(); } else if (fieldName.equals(FIELD_VALUE)) { - validateTensorValue(); + validateTensorCellValue(); } } else { throw new InvalidConstantTensorException(parser, "Only 'address' or 'value' fields are permitted within a cell object"); @@ -169,9 +222,12 @@ public class ConstantTensorJsonValidator { throw new InvalidConstantTensorException(parser, String.format("Index '%s' for dimension '%s' is not an integer", value, dimensionName)); } - private void validateTensorValue() throws IOException { + private void validateTensorCellValue() throws IOException { JsonToken token = parser.nextToken(); + validateNumeric(token); + } + private void validateNumeric(JsonToken token) throws IOException { if (token != JsonToken.VALUE_NUMBER_FLOAT && token != JsonToken.VALUE_NUMBER_INT) { throw new InvalidConstantTensorException(parser, String.format("Tensor value is not a number (%s)", token.toString())); } |