aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-05-22 11:42:17 +0200
committerGitHub <noreply@github.com>2023-05-22 11:42:17 +0200
commit07d0699378a589bcdffa08bba840b695a185ef99 (patch)
treeae4a6057d52f77bd2a14fab214774ffa90b86a03 /config-model/src/main
parentc7a07adf43c13165e49e2aa2ef509ecb2526a48c (diff)
parent95d2e5194acd87facd594201dd1db254a41b1f73 (diff)
Merge pull request #27137 from vespa-engine/arnej/allow-some-short-forms-for-constant-tensors
allow short-form JSON for 1-d constants
Diffstat (limited to 'config-model/src/main')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java76
1 files changed, 66 insertions, 10 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
index 66da43856b1..eccb6910866 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
@@ -28,11 +28,14 @@ public class ConstantTensorJsonValidator {
private static final String FIELD_CELLS = "cells";
private static final String FIELD_ADDRESS = "address";
private static final String FIELD_VALUE = "value";
+ private static final String FIELD_VALUES = "values";
private static final JsonFactory jsonFactory = new JsonFactory();
private JsonParser parser;
private Map<String, TensorType.Dimension> tensorDimensions;
+ private boolean isSingleDenseType = false;
+ private boolean isSingleMappedType = false;
public void validate(String fileName, TensorType type, Reader tensorData) {
if (fileName.endsWith(".json")) {
@@ -57,19 +60,69 @@ public class ConstantTensorJsonValidator {
.dimensions()
.stream()
.collect(Collectors.toMap(TensorType.Dimension::name, Function.identity()));
+ if (type.dimensions().size() == 1) {
+ this.isSingleMappedType = (type.indexedSubtype() == TensorType.empty);
+ this.isSingleDenseType = (type.mappedSubtype() == TensorType.empty);
+ }
+ var top = parser.nextToken();
+ if (top == JsonToken.START_ARRAY) {
+ consumeValuesArray();
+ } else if (top == JsonToken.START_OBJECT) {
+ consumeTopObject();
+ }
+ });
+ }
- assertNextTokenIs(JsonToken.START_OBJECT);
- assertNextTokenIs(JsonToken.FIELD_NAME);
- assertFieldNameIs(FIELD_CELLS);
+ private void consumeValuesArray() throws IOException {
+ if (! isSingleDenseType) {
+ throw new InvalidConstantTensorException(parser, String.format("Field 'values' is only valid for simple vectors (1-d dense tensors"));
+ }
+ assertCurrentTokenIs(JsonToken.START_ARRAY);
+ while (parser.nextToken() != JsonToken.END_ARRAY) {
+ validateNumeric(parser.getCurrentToken());
+ }
+ }
+ private void consumeTopObject() throws IOException {
+ assertCurrentTokenIs(JsonToken.START_OBJECT);
+ assertNextTokenIs(JsonToken.FIELD_NAME);
+ String fieldName = parser.getCurrentName();
+ if (fieldName.equals(FIELD_VALUES)) {
assertNextTokenIs(JsonToken.START_ARRAY);
+ consumeValuesArray();
+ } else if (fieldName.equals(FIELD_CELLS)) {
+ consumeCellsField();
+ } else {
+ throw new InvalidConstantTensorException(parser, String.format("Expected 'cells' or 'values', got '%s'", fieldName));
+ }
+ assertNextTokenIs(JsonToken.END_OBJECT);
+ }
- while (parser.nextToken() != JsonToken.END_ARRAY) {
- validateTensorCell();
- }
+ private void consumeCellsField() throws IOException {
+ var token = parser.nextToken();
+ if (token == JsonToken.START_ARRAY) {
+ consumeLiteralFormArray();
+ } else if (token == JsonToken.START_OBJECT) {
+ consumeSimpleMappedObject();
+ } else {
+ throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be object or array, but got %s", token.toString()));
+ }
+ }
- assertNextTokenIs(JsonToken.END_OBJECT);
- });
+ private void consumeLiteralFormArray() throws IOException {
+ while (parser.nextToken() != JsonToken.END_ARRAY) {
+ validateTensorCell();
+ }
+ }
+
+ private void consumeSimpleMappedObject() throws IOException {
+ if (! isSingleMappedType) {
+ throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be an array of address/value objects"));
+ }
+ while (parser.nextToken() != JsonToken.END_OBJECT) {
+ assertCurrentTokenIs(JsonToken.FIELD_NAME);
+ validateTensorCellValue();
+ }
}
private void validateTensorCell() {
@@ -87,7 +140,7 @@ public class ConstantTensorJsonValidator {
if (fieldName.equals(FIELD_ADDRESS)) {
validateTensorAddress();
} else if (fieldName.equals(FIELD_VALUE)) {
- validateTensorValue();
+ validateTensorCellValue();
}
} else {
throw new InvalidConstantTensorException(parser, "Only 'address' or 'value' fields are permitted within a cell object");
@@ -169,9 +222,12 @@ public class ConstantTensorJsonValidator {
throw new InvalidConstantTensorException(parser, String.format("Index '%s' for dimension '%s' is not an integer", value, dimensionName));
}
- private void validateTensorValue() throws IOException {
+ private void validateTensorCellValue() throws IOException {
JsonToken token = parser.nextToken();
+ validateNumeric(token);
+ }
+ private void validateNumeric(JsonToken token) throws IOException {
if (token != JsonToken.VALUE_NUMBER_FLOAT && token != JsonToken.VALUE_NUMBER_INT) {
throw new InvalidConstantTensorException(parser, String.format("Tensor value is not a number (%s)", token.toString()));
}