From db2de9140b9487e8aa26bbd4ba1aedcad4ecc990 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Tue, 16 May 2023 11:06:12 +0000 Subject: allow short-form JSON for 1-d constants --- .../validation/ConstantTensorJsonValidator.java | 76 +++++++++++++++++++--- 1 file changed, 66 insertions(+), 10 deletions(-) (limited to 'config-model/src/main') diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java index 66da43856b1..eccb6910866 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java @@ -28,11 +28,14 @@ public class ConstantTensorJsonValidator { private static final String FIELD_CELLS = "cells"; private static final String FIELD_ADDRESS = "address"; private static final String FIELD_VALUE = "value"; + private static final String FIELD_VALUES = "values"; private static final JsonFactory jsonFactory = new JsonFactory(); private JsonParser parser; private Map tensorDimensions; + private boolean isSingleDenseType = false; + private boolean isSingleMappedType = false; public void validate(String fileName, TensorType type, Reader tensorData) { if (fileName.endsWith(".json")) { @@ -57,19 +60,69 @@ public class ConstantTensorJsonValidator { .dimensions() .stream() .collect(Collectors.toMap(TensorType.Dimension::name, Function.identity())); + if (type.dimensions().size() == 1) { + this.isSingleMappedType = (type.indexedSubtype() == TensorType.empty); + this.isSingleDenseType = (type.mappedSubtype() == TensorType.empty); + } + var top = parser.nextToken(); + if (top == JsonToken.START_ARRAY) { + consumeValuesArray(); + } else if (top == JsonToken.START_OBJECT) { + consumeTopObject(); + } + }); + } - assertNextTokenIs(JsonToken.START_OBJECT); - assertNextTokenIs(JsonToken.FIELD_NAME); - assertFieldNameIs(FIELD_CELLS); + private void consumeValuesArray() throws IOException { + if (! isSingleDenseType) { + throw new InvalidConstantTensorException(parser, String.format("Field 'values' is only valid for simple vectors (1-d dense tensors")); + } + assertCurrentTokenIs(JsonToken.START_ARRAY); + while (parser.nextToken() != JsonToken.END_ARRAY) { + validateNumeric(parser.getCurrentToken()); + } + } + private void consumeTopObject() throws IOException { + assertCurrentTokenIs(JsonToken.START_OBJECT); + assertNextTokenIs(JsonToken.FIELD_NAME); + String fieldName = parser.getCurrentName(); + if (fieldName.equals(FIELD_VALUES)) { assertNextTokenIs(JsonToken.START_ARRAY); + consumeValuesArray(); + } else if (fieldName.equals(FIELD_CELLS)) { + consumeCellsField(); + } else { + throw new InvalidConstantTensorException(parser, String.format("Expected 'cells' or 'values', got '%s'", fieldName)); + } + assertNextTokenIs(JsonToken.END_OBJECT); + } - while (parser.nextToken() != JsonToken.END_ARRAY) { - validateTensorCell(); - } + private void consumeCellsField() throws IOException { + var token = parser.nextToken(); + if (token == JsonToken.START_ARRAY) { + consumeLiteralFormArray(); + } else if (token == JsonToken.START_OBJECT) { + consumeSimpleMappedObject(); + } else { + throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be object or array, but got %s", token.toString())); + } + } - assertNextTokenIs(JsonToken.END_OBJECT); - }); + private void consumeLiteralFormArray() throws IOException { + while (parser.nextToken() != JsonToken.END_ARRAY) { + validateTensorCell(); + } + } + + private void consumeSimpleMappedObject() throws IOException { + if (! isSingleMappedType) { + throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be an array of address/value objects")); + } + while (parser.nextToken() != JsonToken.END_OBJECT) { + assertCurrentTokenIs(JsonToken.FIELD_NAME); + validateTensorCellValue(); + } } private void validateTensorCell() { @@ -87,7 +140,7 @@ public class ConstantTensorJsonValidator { if (fieldName.equals(FIELD_ADDRESS)) { validateTensorAddress(); } else if (fieldName.equals(FIELD_VALUE)) { - validateTensorValue(); + validateTensorCellValue(); } } else { throw new InvalidConstantTensorException(parser, "Only 'address' or 'value' fields are permitted within a cell object"); @@ -169,9 +222,12 @@ public class ConstantTensorJsonValidator { throw new InvalidConstantTensorException(parser, String.format("Index '%s' for dimension '%s' is not an integer", value, dimensionName)); } - private void validateTensorValue() throws IOException { + private void validateTensorCellValue() throws IOException { JsonToken token = parser.nextToken(); + validateNumeric(token); + } + private void validateNumeric(JsonToken token) throws IOException { if (token != JsonToken.VALUE_NUMBER_FLOAT && token != JsonToken.VALUE_NUMBER_INT) { throw new InvalidConstantTensorException(parser, String.format("Tensor value is not a number (%s)", token.toString())); } -- cgit v1.2.3