diff options
Diffstat (limited to 'config-model')
11 files changed, 477 insertions, 147 deletions
diff --git a/config-model/pom.xml b/config-model/pom.xml index a634a40c93b..afb5a7f49c0 100644 --- a/config-model/pom.xml +++ b/config-model/pom.xml @@ -300,6 +300,11 @@ <groupId>com.yahoo.vespa</groupId> <artifactId>bundle-plugin</artifactId> <extensions>true</extensions> + <configuration> + <allowEmbeddedArtifacts>com.fasterxml.jackson.core:jackson-annotations, com.fasterxml.jackson.core:jackson-core, + com.fasterxml.jackson.core:jackson-databind, com.yahoo.vespa:metrics, com.yahoo.vespa:predicate-search-core, + com.yahoo.vespa:searchcore</allowEmbeddedArtifacts> + </configuration> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> diff --git a/config-model/src/main/java/com/yahoo/schema/derived/AttributeFields.java b/config-model/src/main/java/com/yahoo/schema/derived/AttributeFields.java index c3531d03d3f..12ca67bf2c9 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/AttributeFields.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/AttributeFields.java @@ -51,9 +51,9 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce if (unsupportedFieldType(field)) { return; // Ignore complex struct and map fields for indexed search (only supported for streaming search) } - if (isArrayOfSimpleStruct(field)) { + if (isArrayOfSimpleStruct(field, false)) { deriveArrayOfSimpleStruct(field); - } else if (isMapOfSimpleStruct(field)) { + } else if (isMapOfSimpleStruct(field, false)) { deriveMapOfSimpleStruct(field); } else if (isMapOfPrimitiveType(field)) { deriveMapOfPrimitiveType(field); diff --git a/config-model/src/main/java/com/yahoo/schema/derived/ImportedFields.java b/config-model/src/main/java/com/yahoo/schema/derived/ImportedFields.java index fa3f49f06d5..122048d02b9 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/ImportedFields.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/ImportedFields.java @@ -61,9 +61,9 @@ public class ImportedFields extends Derived implements ImportedFieldsConfig.Prod ImmutableSDField targetField = field.targetField(); if (GeoPos.isAnyPos(targetField)) { // no action needed - } else if (isArrayOfSimpleStruct(targetField)) { + } else if (isArrayOfSimpleStruct(targetField, false)) { considerNestedFields(builder, field); - } else if (isMapOfSimpleStruct(targetField)) { + } else if (isMapOfSimpleStruct(targetField, false)) { considerSimpleField(builder, field.getNestedField("key")); considerNestedFields(builder, field.getNestedField("value")); } else if (isMapOfPrimitiveType(targetField)) { diff --git a/config-model/src/main/java/com/yahoo/schema/document/ComplexAttributeFieldUtils.java b/config-model/src/main/java/com/yahoo/schema/document/ComplexAttributeFieldUtils.java index 993c9180f78..5e4ee6d4b27 100644 --- a/config-model/src/main/java/com/yahoo/schema/document/ComplexAttributeFieldUtils.java +++ b/config-model/src/main/java/com/yahoo/schema/document/ComplexAttributeFieldUtils.java @@ -22,26 +22,31 @@ import com.yahoo.document.StructDataType; public class ComplexAttributeFieldUtils { public static boolean isSupportedComplexField(ImmutableSDField field) { - return (isArrayOfSimpleStruct(field) || - isMapOfSimpleStruct(field) || + return isSupportedComplexField(field, false); + } + + // TODO: Remove the stricterValidation flag when this is changed to being always on. + public static boolean isSupportedComplexField(ImmutableSDField field, boolean stricterValidation) { + return (isArrayOfSimpleStruct(field, stricterValidation) || + isMapOfSimpleStruct(field, stricterValidation) || isMapOfPrimitiveType(field)); } - public static boolean isArrayOfSimpleStruct(ImmutableSDField field) { + public static boolean isArrayOfSimpleStruct(ImmutableSDField field, boolean stricterValidation) { if (field.getDataType() instanceof ArrayDataType) { ArrayDataType arrayType = (ArrayDataType)field.getDataType(); - return isStructWithPrimitiveStructFieldAttributes(arrayType.getNestedType(), field); + return isStructWithPrimitiveStructFieldAttributes(arrayType.getNestedType(), field, stricterValidation); } else { return false; } } - public static boolean isMapOfSimpleStruct(ImmutableSDField field) { + public static boolean isMapOfSimpleStruct(ImmutableSDField field, boolean stricterValidation) { if (field.getDataType() instanceof MapDataType) { MapDataType mapType = (MapDataType)field.getDataType(); return isPrimitiveType(mapType.getKeyType()) && isStructWithPrimitiveStructFieldAttributes(mapType.getValueType(), - field.getStructField("value")); + field.getStructField("value"), stricterValidation); } else { return false; } @@ -57,7 +62,7 @@ public class ComplexAttributeFieldUtils { } } - private static boolean isStructWithPrimitiveStructFieldAttributes(DataType type, ImmutableSDField field) { + private static boolean isStructWithPrimitiveStructFieldAttributes(DataType type, ImmutableSDField field, boolean stricterValidation) { if (type instanceof StructDataType && ! GeoPos.isPos(type)) { for (ImmutableSDField structField : field.getStructFields()) { Attribute attribute = structField.getAttributes().get(structField.getName()); @@ -70,6 +75,9 @@ public class ComplexAttributeFieldUtils { return false; } } + if (stricterValidation && !structField.isImportedField() && hasStructFieldAttributes(structField)) { + return false; + } } return true; } else { @@ -77,6 +85,19 @@ public class ComplexAttributeFieldUtils { } } + private static boolean hasStructFieldAttributes(ImmutableSDField field) { + for (var structField : field.getStructFields()) { + var attribute = structField.getAttributes().get(structField.getName()); + if (attribute != null) { + return true; + } + if (hasStructFieldAttributes(structField)) { + return true; + } + } + return false; + } + public static boolean isPrimitiveType(Attribute attribute) { return attribute.getCollectionType().equals(Attribute.CollectionType.SINGLE) && isPrimitiveType(attribute.getDataType()); @@ -92,9 +113,9 @@ public class ComplexAttributeFieldUtils { } public static boolean isComplexFieldWithOnlyStructFieldAttributes(ImmutableSDField field) { - if (isArrayOfSimpleStruct(field)) { + if (isArrayOfSimpleStruct(field, false)) { return hasOnlyStructFieldAttributes(field); - } else if (isMapOfSimpleStruct(field)) { + } else if (isMapOfSimpleStruct(field, false)) { return (field.getStructField("key").hasSingleAttribute()) && hasOnlyStructFieldAttributes(field.getStructField("value")); } else if (isMapOfPrimitiveType(field)) { diff --git a/config-model/src/main/java/com/yahoo/schema/processing/ImportedFieldsResolver.java b/config-model/src/main/java/com/yahoo/schema/processing/ImportedFieldsResolver.java index ee465be44f2..8e44bd026a3 100644 --- a/config-model/src/main/java/com/yahoo/schema/processing/ImportedFieldsResolver.java +++ b/config-model/src/main/java/com/yahoo/schema/processing/ImportedFieldsResolver.java @@ -52,9 +52,9 @@ public class ImportedFieldsResolver extends Processor { ImmutableSDField targetField = getTargetField(importedField, reference); if (GeoPos.isAnyPos(targetField)) { resolveImportedPositionField(importedField, reference, targetField, validate); - } else if (isArrayOfSimpleStruct(targetField)) { + } else if (isArrayOfSimpleStruct(targetField, false)) { resolveImportedArrayOfStructField(importedField, reference, targetField, validate); - } else if (isMapOfSimpleStruct(targetField)) { + } else if (isMapOfSimpleStruct(targetField, false)) { resolveImportedMapOfStructField(importedField, reference, targetField, validate); } else if (isMapOfPrimitiveType(targetField)) { resolveImportedMapOfPrimitiveField(importedField, reference, targetField, validate); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ComplexFieldsWithStructFieldAttributesValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ComplexFieldsWithStructFieldAttributesValidator.java index 8515c34a377..d2999a24775 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ComplexFieldsWithStructFieldAttributesValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ComplexFieldsWithStructFieldAttributesValidator.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation; +import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.schema.Schema; import com.yahoo.schema.derived.SchemaInfo; @@ -13,6 +14,7 @@ import com.yahoo.vespa.model.search.SearchCluster; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.logging.Level; import java.util.stream.Collectors; /** @@ -31,34 +33,44 @@ public class ComplexFieldsWithStructFieldAttributesValidator extends Validator { if (cluster.isStreaming()) continue; for (SchemaInfo spec : cluster.schemas().values()) { - validateComplexFields(cluster.getClusterName(), spec.fullSchema()); + validateComplexFields(cluster.getClusterName(), spec.fullSchema(), deployState.getDeployLogger()); } } } - private static void validateComplexFields(String clusterName, Schema schema) { - String unsupportedFields = schema.allFields() - .filter(field -> isUnsupportedComplexField(field)) - .map(ComplexFieldsWithStructFieldAttributesValidator::toString) - .collect(Collectors.joining(", ")); - + private static void validateComplexFields(String clusterName, Schema schema, DeployLogger logger) { + String unsupportedFields = validateComplexFields(clusterName, schema, false); + if (!unsupportedFields.isEmpty()) { + throw new IllegalArgumentException(getErrorMessage(clusterName, schema, unsupportedFields)); + } + unsupportedFields = validateComplexFields(clusterName, schema, true); if (!unsupportedFields.isEmpty()) { - throw new IllegalArgumentException( - String.format("For cluster '%s', search '%s': The following complex fields do not support using struct field attributes: %s. " + - "Only supported for the following complex field types: array or map of struct with primitive types, map of primitive types. " + - "The supported primitive types are: byte, int, long, float, double and string", - clusterName, schema.getName(), unsupportedFields)); + logger.logApplicationPackage(Level.WARNING, getErrorMessage(clusterName, schema, unsupportedFields)); } } - private static boolean isUnsupportedComplexField(ImmutableSDField field) { + private static String validateComplexFields(String clusterName, Schema schema, boolean stricterValidation) { + return schema.allFields() + .filter(field -> isUnsupportedComplexField(field, stricterValidation)) + .map(ComplexFieldsWithStructFieldAttributesValidator::toString) + .collect(Collectors.joining(", ")); + } + + private static String getErrorMessage(String clusterName, Schema schema, String unsupportedFields) { + return String.format("For cluster '%s', search '%s': The following complex fields do not support using struct field attributes: %s. " + + "Only supported for the following complex field types: array or map of struct with primitive types, map of primitive types. " + + "The supported primitive types are: byte, int, long, float, double and string", + clusterName, schema.getName(), unsupportedFields); + } + + private static boolean isUnsupportedComplexField(ImmutableSDField field, boolean stricterValidation) { return (field.usesStructOrMap() && - !isSupportedComplexField(field) && + !isSupportedComplexField(field, stricterValidation) && hasStructFieldAttributes(field.getStructFields())); } - private static boolean isSupportedComplexField(ImmutableSDField field) { - return (ComplexAttributeFieldUtils.isSupportedComplexField(field) || + private static boolean isSupportedComplexField(ImmutableSDField field, boolean stricterValidation) { + return (ComplexAttributeFieldUtils.isSupportedComplexField(field, stricterValidation) || GeoPos.isAnyPos(field)); } @@ -82,7 +94,8 @@ public class ComplexFieldsWithStructFieldAttributesValidator extends Validator { if (structField.usesStructOrMap() && structField.wasConfiguredToDoAttributing()) { result.add(structField.getName()); } - result.addAll(getStructFieldAttributes(structField.getStructFields(), returnAllTypes)); + // If we encounter struct field attributes underneath this level, those are not supported and should be returned. + result.addAll(getStructFieldAttributes(structField.getStructFields(), true)); } return result; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java index eccb6910866..df3cd4103d9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java @@ -11,6 +11,7 @@ import java.io.IOException; import java.io.Reader; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -22,6 +23,7 @@ import java.util.stream.Collectors; * ConstantTensorJsonValidator strictly validates a constant tensor in JSON format read from a Reader object * * @author Vegard Sjonfjell + * @author arnej */ public class ConstantTensorJsonValidator { @@ -29,17 +31,63 @@ public class ConstantTensorJsonValidator { private static final String FIELD_ADDRESS = "address"; private static final String FIELD_VALUE = "value"; private static final String FIELD_VALUES = "values"; + private static final String FIELD_BLOCKS = "blocks"; + private static final String FIELD_TYPE = "type"; private static final JsonFactory jsonFactory = new JsonFactory(); private JsonParser parser; - private Map<String, TensorType.Dimension> tensorDimensions; - private boolean isSingleDenseType = false; - private boolean isSingleMappedType = false; + private final TensorType tensorType; + private final Map<String, TensorType.Dimension> tensorDimensions = new HashMap<>(); + private final List<String> denseDims = new ArrayList<>(); + private final List<String> mappedDims = new ArrayList<>(); + private int numIndexedDims = 0; + private int numMappedDims = 0; + private boolean seenCells = false; + private boolean seenValues = false; + private boolean seenBlocks = false; + private boolean seenType = false; + private boolean seenSimpleMapValue = false; + + private boolean isScalar() { + return (numIndexedDims == 0 && numMappedDims == 0); + } + private boolean isDense() { + return (numIndexedDims > 0 && numMappedDims == 0); + } + private boolean isSparse() { + return (numIndexedDims == 0 && numMappedDims > 0); + } + private boolean isSingleDense() { + return (numIndexedDims == 1 && numMappedDims == 0); + } + private boolean isSingleSparse() { + return (numIndexedDims == 0 && numMappedDims == 1); + } + private boolean isMixed() { + return (numIndexedDims > 0 && numMappedDims > 0); + } - public void validate(String fileName, TensorType type, Reader tensorData) { + public ConstantTensorJsonValidator(TensorType type) { + this.tensorType = type; + for (var dim : type.dimensions()) { + tensorDimensions.put(dim.name(), dim); + switch (dim.type()) { + case mapped: + ++numMappedDims; + mappedDims.add(dim.name()); + break; + case indexedBound: + case indexedUnbound: + ++numIndexedDims; + denseDims.add(dim.name()); + } + } + } + + public void validate(String fileName, Reader tensorData) { if (fileName.endsWith(".json")) { - validateTensor(type, tensorData); + validateTensor(tensorData); } else if (fileName.endsWith(".json.lz4")) { // don't validate; the cost probably outweights the advantage @@ -53,127 +101,156 @@ public class ConstantTensorJsonValidator { } } - private void validateTensor(TensorType type, Reader tensorData) { - wrapIOException(() -> { + private void validateTensor(Reader tensorData) { + try { this.parser = jsonFactory.createParser(tensorData); - this.tensorDimensions = type - .dimensions() - .stream() - .collect(Collectors.toMap(TensorType.Dimension::name, Function.identity())); - if (type.dimensions().size() == 1) { - this.isSingleMappedType = (type.indexedSubtype() == TensorType.empty); - this.isSingleDenseType = (type.mappedSubtype() == TensorType.empty); - } var top = parser.nextToken(); - if (top == JsonToken.START_ARRAY) { + if (top == JsonToken.START_ARRAY && isDense()) { consumeValuesArray(); + return; } else if (top == JsonToken.START_OBJECT) { consumeTopObject(); + return; + } else if (isScalar()) { + if (top == JsonToken.VALUE_NUMBER_FLOAT || top == JsonToken.VALUE_NUMBER_INT) { + return; + } + } + throw new InvalidConstantTensorException( + parser, String.format("Unexpected first token '%s' for constant with type %s", + parser.getText(), tensorType.toString())); + } catch (IOException e) { + if (parser != null) { + throw new InvalidConstantTensorException(parser, e); } - }); + throw new InvalidConstantTensorException(e); + } } private void consumeValuesArray() throws IOException { - if (! isSingleDenseType) { - throw new InvalidConstantTensorException(parser, String.format("Field 'values' is only valid for simple vectors (1-d dense tensors")); - } - assertCurrentTokenIs(JsonToken.START_ARRAY); - while (parser.nextToken() != JsonToken.END_ARRAY) { - validateNumeric(parser.getCurrentToken()); - } + consumeValuesNesting(0); } private void consumeTopObject() throws IOException { - assertCurrentTokenIs(JsonToken.START_OBJECT); - assertNextTokenIs(JsonToken.FIELD_NAME); - String fieldName = parser.getCurrentName(); - if (fieldName.equals(FIELD_VALUES)) { - assertNextTokenIs(JsonToken.START_ARRAY); - consumeValuesArray(); - } else if (fieldName.equals(FIELD_CELLS)) { - consumeCellsField(); - } else { - throw new InvalidConstantTensorException(parser, String.format("Expected 'cells' or 'values', got '%s'", fieldName)); + for (var cur = parser.nextToken(); cur != JsonToken.END_OBJECT; cur = parser.nextToken()) { + assertCurrentTokenIs(JsonToken.FIELD_NAME); + String fieldName = parser.getCurrentName(); + switch (fieldName) { + case FIELD_TYPE: + consumeTypeField(); + break; + case FIELD_VALUES: + consumeValuesField(); + break; + case FIELD_CELLS: + consumeCellsField(); + break; + case FIELD_BLOCKS: + consumeBlocksField(); + break; + default: + consumeAnyField(fieldName, parser.nextToken()); + break; + } + } + if (seenSimpleMapValue) { + if (! isSingleSparse()) { + throw new InvalidConstantTensorException(parser, String.format("Cannot use {label: value} format for constant of type %s", tensorType.toString())); + } + if (seenCells || seenValues || seenBlocks || seenType) { + throw new InvalidConstantTensorException(parser, String.format("Cannot use {label: value} format together with '%s'", + (seenCells ? FIELD_CELLS : + (seenValues ? FIELD_VALUES : + (seenBlocks ? FIELD_BLOCKS : FIELD_TYPE))))); + } + } + if (seenCells) { + if (seenValues || seenBlocks) { + throw new InvalidConstantTensorException(parser, String.format("Cannot use both '%s' and '%s' at the same time", + FIELD_CELLS, (seenValues ? FIELD_VALUES : FIELD_BLOCKS))); + } + } + if (seenValues && seenBlocks) { + throw new InvalidConstantTensorException(parser, String.format("Cannot use both '%s' and '%s' at the same time", + FIELD_VALUES, FIELD_BLOCKS)); } - assertNextTokenIs(JsonToken.END_OBJECT); } private void consumeCellsField() throws IOException { - var token = parser.nextToken(); - if (token == JsonToken.START_ARRAY) { + var cur = parser.nextToken(); + if (cur == JsonToken.START_ARRAY) { consumeLiteralFormArray(); - } else if (token == JsonToken.START_OBJECT) { + seenCells = true; + } else if (cur == JsonToken.START_OBJECT) { consumeSimpleMappedObject(); + seenCells = true; } else { - throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be object or array, but got %s", token.toString())); + consumeAnyField(FIELD_BLOCKS, cur); } } private void consumeLiteralFormArray() throws IOException { while (parser.nextToken() != JsonToken.END_ARRAY) { - validateTensorCell(); + validateLiteralFormCell(); } } private void consumeSimpleMappedObject() throws IOException { - if (! isSingleMappedType) { - throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be an array of address/value objects")); + if (! isSingleSparse()) { + throw new InvalidConstantTensorException(parser, String.format("Cannot use {label: value} format for constant of type %s", tensorType.toString())); } - while (parser.nextToken() != JsonToken.END_OBJECT) { + for (var cur = parser.nextToken(); cur != JsonToken.END_OBJECT; cur = parser.nextToken()) { assertCurrentTokenIs(JsonToken.FIELD_NAME); - validateTensorCellValue(); + validateNumeric(parser.getCurrentName(), parser.nextToken()); } } - private void validateTensorCell() { - wrapIOException(() -> { - assertCurrentTokenIs(JsonToken.START_OBJECT); - - List<String> fieldNameCandidates = new ArrayList<>(Arrays.asList(FIELD_ADDRESS, FIELD_VALUE)); - for (int i = 0; i < 2; i++) { - assertNextTokenIs(JsonToken.FIELD_NAME); - String fieldName = parser.getCurrentName(); - - if (fieldNameCandidates.contains(fieldName)) { - fieldNameCandidates.remove(fieldName); - - if (fieldName.equals(FIELD_ADDRESS)) { - validateTensorAddress(); - } else if (fieldName.equals(FIELD_VALUE)) { - validateTensorCellValue(); - } - } else { - throw new InvalidConstantTensorException(parser, "Only 'address' or 'value' fields are permitted within a cell object"); - } + private void validateLiteralFormCell() throws IOException { + assertCurrentTokenIs(JsonToken.START_OBJECT); + boolean seenAddress = false; + boolean seenValue = false; + for (int i = 0; i < 2; i++) { + assertNextTokenIs(JsonToken.FIELD_NAME); + String fieldName = parser.getCurrentName(); + switch (fieldName) { + case FIELD_ADDRESS: + validateTensorAddress(new HashSet<>(tensorDimensions.keySet())); + seenAddress = true; + break; + case FIELD_VALUE: + validateNumeric(FIELD_VALUE, parser.nextToken()); + seenValue = true; + break; + default: + throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a cell object", + FIELD_ADDRESS, FIELD_VALUE)); } - - assertNextTokenIs(JsonToken.END_OBJECT); - }); + } + if (! seenAddress) { + throw new InvalidConstantTensorException(parser, String.format("Missing '%s' field in cell object", FIELD_ADDRESS)); + } + if (! seenValue) { + throw new InvalidConstantTensorException(parser, String.format("Missing '%s' field in cell object", FIELD_VALUE)); + } + assertNextTokenIs(JsonToken.END_OBJECT); } - private void validateTensorAddress() throws IOException { + private void validateTensorAddress(Set<String> cellDimensions) throws IOException { assertNextTokenIs(JsonToken.START_OBJECT); - - Set<String> cellDimensions = new HashSet<>(tensorDimensions.keySet()); - // Iterate within the address key, value pairs while ((parser.nextToken() != JsonToken.END_OBJECT)) { assertCurrentTokenIs(JsonToken.FIELD_NAME); - String dimensionName = parser.getCurrentName(); TensorType.Dimension dimension = tensorDimensions.get(dimensionName); if (dimension == null) { throw new InvalidConstantTensorException(parser, String.format("Tensor dimension '%s' does not exist", parser.getCurrentName())); } - if (!cellDimensions.contains(dimensionName)) { throw new InvalidConstantTensorException(parser, String.format("Duplicate tensor dimension '%s'", parser.getCurrentName())); } - cellDimensions.remove(dimensionName); validateLabel(dimension); } - if (!cellDimensions.isEmpty()) { throw new InvalidConstantTensorException(parser, String.format("Tensor address missing dimension(s) %s", Joiner.on(", ").join(cellDimensions))); } @@ -186,9 +263,9 @@ public class ConstantTensorJsonValidator { */ private void validateLabel(TensorType.Dimension dimension) throws IOException { JsonToken token = parser.nextToken(); - if (token != JsonToken.VALUE_STRING) + if (token != JsonToken.VALUE_STRING) { throw new InvalidConstantTensorException(parser, String.format("Tensor label is not a string (%s)", token.toString())); - + } if (dimension instanceof TensorType.IndexedBoundDimension) { validateBoundIndex((TensorType.IndexedBoundDimension) dimension); } else if (dimension instanceof TensorType.IndexedUnboundDimension) { @@ -196,40 +273,31 @@ public class ConstantTensorJsonValidator { } } - private void validateBoundIndex(TensorType.IndexedBoundDimension dimension) { - wrapIOException(() -> { - try { - int value = Integer.parseInt(parser.getValueAsString()); - if (value >= dimension.size().get()) - throw new InvalidConstantTensorException(parser, String.format("Index %s not within limits of bound dimension '%s'", value, dimension.name())); - } catch (NumberFormatException e) { - throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name()); - } - }); + private void validateBoundIndex(TensorType.IndexedBoundDimension dimension) throws IOException { + try { + int value = Integer.parseInt(parser.getValueAsString()); + if (value >= dimension.size().get()) + throw new InvalidConstantTensorException(parser, String.format("Index %s not within limits of bound dimension '%s'", value, dimension.name())); + } catch (NumberFormatException e) { + throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name()); + } } - private void validateUnboundIndex(TensorType.Dimension dimension) { - wrapIOException(() -> { - try { - Integer.parseInt(parser.getValueAsString()); - } catch (NumberFormatException e) { - throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name()); - } - }); + private void validateUnboundIndex(TensorType.Dimension dimension) throws IOException { + try { + Integer.parseInt(parser.getValueAsString()); + } catch (NumberFormatException e) { + throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name()); + } } private void throwCoordinateIsNotInteger(String value, String dimensionName) { throw new InvalidConstantTensorException(parser, String.format("Index '%s' for dimension '%s' is not an integer", value, dimensionName)); } - private void validateTensorCellValue() throws IOException { - JsonToken token = parser.nextToken(); - validateNumeric(token); - } - - private void validateNumeric(JsonToken token) throws IOException { + private void validateNumeric(String where, JsonToken token) throws IOException { if (token != JsonToken.VALUE_NUMBER_FLOAT && token != JsonToken.VALUE_NUMBER_INT) { - throw new InvalidConstantTensorException(parser, String.format("Tensor value is not a number (%s)", token.toString())); + throw new InvalidConstantTensorException(parser, String.format("Inside '%s': cell value is not a number (%s)", where, token.toString())); } } @@ -265,6 +333,9 @@ public class ConstantTensorJsonValidator { super("Failed to parse JSON stream " + parser.getCurrentLocation().toString(), base); } + InvalidConstantTensorException(IOException base) { + super("Failed to parse JSON stream: " + base.getMessage(), base); + } } @FunctionalInterface @@ -280,4 +351,125 @@ public class ConstantTensorJsonValidator { } } + private void consumeValuesNesting(int level) throws IOException { + assertCurrentTokenIs(JsonToken.START_ARRAY); + if (level >= denseDims.size()) { + throw new InvalidConstantTensorException( + parser, String.format("Too deep array nesting for constant with type %s", tensorType.toString())); + } + var dim = tensorDimensions.get(denseDims.get(level)); + long count = 0; + for (var cur = parser.nextToken(); cur != JsonToken.END_ARRAY; cur = parser.nextToken()) { + if (level + 1 == denseDims.size()) { + validateNumeric(FIELD_VALUES, cur); + } else if (cur == JsonToken.START_ARRAY) { + consumeValuesNesting(level + 1); + } else { + throw new InvalidConstantTensorException( + parser, String.format("Unexpected token %s '%s'", cur.toString(), parser.getText())); + } + ++count; + } + if (dim.size().isPresent()) { + var sz = dim.size().get(); + if (sz != count) { + throw new InvalidConstantTensorException( + parser, String.format("Dimension '%s' has size %d but array had %d values", dim.name(), sz, count)); + } + } + } + + private void consumeTypeField() throws IOException { + var cur = parser.nextToken(); + if (cur == JsonToken.VALUE_STRING) { + seenType = true; + } else if (isSingleSparse()) { + validateNumeric(FIELD_TYPE, cur); + seenSimpleMapValue = true; + } else { + throw new InvalidConstantTensorException( + parser, String.format("Field '%s' should contain the tensor type as a string, got %s", FIELD_TYPE, parser.getText())); + } + } + + private void consumeValuesField() throws IOException { + var cur = parser.nextToken(); + if (isDense() && cur == JsonToken.START_ARRAY) { + consumeValuesArray(); + seenValues = true; + } else { + consumeAnyField(FIELD_VALUES, cur); + } + } + + private void consumeBlocksField() throws IOException { + var cur = parser.nextToken(); + if (cur == JsonToken.START_ARRAY) { + consumeBlocksArray(); + seenBlocks = true; + } else if (cur == JsonToken.START_OBJECT) { + consumeBlocksObject(); + seenBlocks = true; + } else { + consumeAnyField(FIELD_BLOCKS, cur); + } + } + + private void consumeAnyField(String fieldName, JsonToken cur) throws IOException { + if (isSingleSparse()) { + validateNumeric(FIELD_CELLS, cur); + seenSimpleMapValue = true; + } else { + throw new InvalidConstantTensorException( + parser, String.format("Unexpected content '%s' for field '%s'", parser.getText(), fieldName)); + } + } + + private void consumeBlocksArray() throws IOException { + if (! isMixed()) { + throw new InvalidConstantTensorException(parser, String.format("Cannot use blocks format:[] for constant of type %s", tensorType.toString())); + } + while (parser.nextToken() != JsonToken.END_ARRAY) { + assertCurrentTokenIs(JsonToken.START_OBJECT); + boolean seenAddress = false; + boolean seenValues = false; + for (int i = 0; i < 2; i++) { + assertNextTokenIs(JsonToken.FIELD_NAME); + String fieldName = parser.getCurrentName(); + switch (fieldName) { + case FIELD_ADDRESS: + validateTensorAddress(new HashSet<>(mappedDims)); + seenAddress = true; + break; + case FIELD_VALUES: + assertNextTokenIs(JsonToken.START_ARRAY); + consumeValuesArray(); + seenValues = true; + break; + default: + throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a block object", + FIELD_ADDRESS, FIELD_VALUES)); + } + } + if (! seenAddress) { + throw new InvalidConstantTensorException(parser, String.format("Missing '%s' field in block object", FIELD_ADDRESS)); + } + if (! seenValues) { + throw new InvalidConstantTensorException(parser, String.format("Missing '%s' field in block object", FIELD_VALUES)); + } + assertNextTokenIs(JsonToken.END_OBJECT); + } + } + + private void consumeBlocksObject() throws IOException { + if (numMappedDims > 1 || ! isMixed()) { + throw new InvalidConstantTensorException(parser, String.format("Cannot use blocks:{} format for constant of type %s", tensorType.toString())); + } + while (parser.nextToken() != JsonToken.END_OBJECT) { + assertCurrentTokenIs(JsonToken.FIELD_NAME); + assertNextTokenIs(JsonToken.START_ARRAY); + consumeValuesArray(); + } + } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java index bc91c7cbad2..b5fc41eaac9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java @@ -59,9 +59,7 @@ public class ConstantValidator extends Validator { } ApplicationFile tensorApplicationFile = application.getFile(Path.fromString(constantFile)); - new ConstantTensorJsonValidator().validate(constantFile, - rankingConstant.type(), - tensorApplicationFile.createReader()); + new ConstantTensorJsonValidator(rankingConstant.type()).validate(constantFile, tensorApplicationFile.createReader()); } private static class ExceptionMessageCollector { diff --git a/config-model/src/test/java/com/yahoo/schema/document/ComplexAttributeFieldUtilsTestCase.java b/config-model/src/test/java/com/yahoo/schema/document/ComplexAttributeFieldUtilsTestCase.java index 7a89f52268f..310ede6bae2 100644 --- a/config-model/src/test/java/com/yahoo/schema/document/ComplexAttributeFieldUtilsTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/document/ComplexAttributeFieldUtilsTestCase.java @@ -30,11 +30,11 @@ public class ComplexAttributeFieldUtilsTestCase { } boolean isArrayOfSimpleStruct() { - return ComplexAttributeFieldUtils.isArrayOfSimpleStruct(field()); + return ComplexAttributeFieldUtils.isArrayOfSimpleStruct(field(), false); } boolean isMapOfSimpleStruct() { - return ComplexAttributeFieldUtils.isMapOfSimpleStruct(field()); + return ComplexAttributeFieldUtils.isMapOfSimpleStruct(field(), false); } boolean isMapOfPrimitiveType() { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ComplexFieldsValidatorTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ComplexFieldsValidatorTestCase.java index c673d5899e8..04abd4e4836 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ComplexFieldsValidatorTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ComplexFieldsValidatorTestCase.java @@ -18,7 +18,6 @@ import java.util.List; import java.util.logging.Level; import static com.yahoo.config.model.test.TestUtil.joinLines; -import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -68,7 +67,33 @@ public class ComplexFieldsValidatorTestCase { "}", "}")); }); - assertTrue(exception.getMessage().contains(getExpectedMessage("docTopics (docTopics.topics)"))); + assertTrue(exception.getMessage().contains(getExpectedMessage("docTopics (docTopics.topics, docTopics.topics.id, docTopics.topics.label)"))); + } + + @Test + void logs_warning_when_struct_field_inside_nested_struct_array_is_specified_as_attribute() throws IOException, SAXException { + var logger = new MyLogger(); + createModelAndValidate(joinLines( + "schema test {", + "document test {", + "struct item {", + "field name type string {}", + "field color type string {}", + "field type type string {}", + "}", + "struct itembox {", + "field items type array<item> {}", + "}", + "field cabinet type map<string, itembox> {", + "struct-field key { indexing: attribute }", + "struct-field value.items {", + "struct-field name { indexing: attribute }", + "struct-field color { indexing: attribute }", + "}", + "}", + "}", + "}"), logger); + assertTrue(logger.message.toString().contains(getExpectedMessage("cabinet (cabinet.value.items.name, cabinet.value.items.color)"))); } private String getExpectedMessage(String unsupportedFields) { @@ -105,7 +130,7 @@ public class ComplexFieldsValidatorTestCase { "}", "}", "}"), logger); - assertThat(logger.message.toString().contains( + assertTrue(logger.message.toString().contains( "For cluster 'mycluster', schema 'test': " + "The following complex fields have struct fields with 'indexing: index' which is not supported and has no effect: " + "topics (topics.id, topics.label). " + diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java index 42be1592eca..747315c1fdf 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java @@ -19,7 +19,7 @@ public class ConstantTensorJsonValidatorTest { } private static void validateTensorJson(TensorType tensorType, Reader jsonTensorReader) { - new ConstantTensorJsonValidator().validate("dummy.json", tensorType, jsonTensorReader); + new ConstantTensorJsonValidator(tensorType).validate("dummy.json", jsonTensorReader); } @Test @@ -207,8 +207,8 @@ public class ConstantTensorJsonValidatorTest { " }", " ]", "}")); - }); - assertTrue(exception.getMessage().contains("Tensor value is not a number (VALUE_STRING)")); + }); + assertTrue(exception.getMessage().contains("Inside 'value': cell value is not a number (VALUE_STRING)")); } @Test @@ -281,8 +281,7 @@ public class ConstantTensorJsonValidatorTest { " }", "}")); }); - System.err.println("msg: " + exception.getMessage()); - assertTrue(exception.getMessage().contains("Expected 'cells' or 'values', got 'stats'")); + assertTrue(exception.getMessage().contains("Unexpected content '{' for field 'stats'")); } @Test @@ -302,4 +301,81 @@ public class ConstantTensorJsonValidatorTest { inputJsonToReader("{'cells':{'a':5,'b':4.0,'c':3.1,'d':-2,'e':-1.0}}")); } + @Test + void ensure_that_matrices_work() { + validateTensorJson( + TensorType.fromSpec("tensor(x[2], y[3])"), + inputJsonToReader( + "[", + " [ 1, 2, 3],", + " [ 4, 5, 6]", + "]")); + validateTensorJson( + TensorType.fromSpec("tensor(x[2], y[3])"), + inputJsonToReader( + "{'values':[", + " [ 1, 2, 3],", + " [ 4, 5, 6]", + "]}")); + } + + @Test + void ensure_that_simple_maps_work() { + validateTensorJson( + TensorType.fromSpec("tensor(category{})"), + inputJsonToReader( + "{", + " 'foo': 1,", + " 'bar': 2,", + " 'type': 3,", + " 'cells': 4,", + " 'value': 5,", + " 'values': 6,", + " 'blocks': 7,", + " 'anything': 8", + "}")); + validateTensorJson( + TensorType.fromSpec("tensor(category{})"), + inputJsonToReader( + "{'cells':{", + " 'foo': 1,", + " 'bar': 2,", + " 'type': 3,", + " 'cells': 4,", + " 'value': 5,", + " 'values': 6,", + " 'blocks': 7,", + " 'anything': 8", + "}}")); + } + + @Test + void ensure_that_mixing_formats_disallowed() { + Throwable exception = assertThrows(InvalidConstantTensorException.class, () -> { + validateTensorJson( + TensorType.fromSpec("tensor(x{})"), + inputJsonToReader("{ 'a': 1.0, 'cells': { 'b': 2.0 } }")); + + }); + assertTrue(exception.getMessage().contains("Cannot use {label: value} format together with 'cells'")); + } + + @Test + void ensure_that_simple_blocks_work() { + validateTensorJson( + TensorType.fromSpec("tensor(a{},b[3])"), + inputJsonToReader( + "{'blocks':{'foo':[1,2,3], 'bar':[4,5,6]}}")); + } + + @Test + void ensure_that_complex_blocks_work() { + validateTensorJson( + TensorType.fromSpec("tensor(a{},b[3],c{},d[2])"), + inputJsonToReader( + "{'blocks':[", + "{'address':{'a':'foo','c':'bar'},'values':[[1,2],[3,4],[5,6]]},", + "{'address':{'a':'qux','c':'zip'},'values':[[9,8],[7,6],[5,4]]}]}")); + } + } |