diff options
author | Jon Bratseth <bratseth@gmail.com> | 2020-06-08 22:49:49 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2020-06-08 22:49:49 +0200 |
commit | 7f0f68d16103259bf3f2174543c6fbd3456a22fa (patch) | |
tree | 920dbb5d7b8e1bac190ed82937d6679a5c5d7e64 | |
parent | 0b1348868b7b4ec23925b20bd7bea4fa5f0d53e2 (diff) |
Disallow unbound tensor dimensions in ranking constants
8 files changed, 54 insertions, 48 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java index 7b7265e02ae..b41cf582204 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java @@ -56,7 +56,9 @@ public class RankingConstant { this.pathType = PathType.URI; } - public void setType(TensorType tensorType) { this.tensorType = tensorType; } + public void setType(TensorType type) { + this.tensorType = type; + } /** Initiate sending of this constant to some services over file distribution */ public void sendTo(Collection<? extends AbstractService> services) { @@ -78,6 +80,9 @@ public class RankingConstant { throw new IllegalArgumentException("Ranking constants must have a file or uri."); if (tensorType == null) throw new IllegalArgumentException("Ranking constant '" + name + "' must have a type."); + if (tensorType.dimensions().stream().anyMatch(d -> d.isIndexed() && d.size().isEmpty())) + throw new IllegalArgumentException("Illegal type in field " + name + " type " + tensorType + + ": Dense tensor dimensions must have a size"); } public String toString() { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java index c6bf54af760..75e6922ad15 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java @@ -90,7 +90,7 @@ public class ConstantTensorJsonValidator { validateTensorValue(); } } else { - throw new InvalidConstantTensor(parser, "Only \"address\" or \"value\" fields are permitted within a cell object"); + throw new InvalidConstantTensor(parser, "Only 'address' or 'value' fields are permitted within a cell object"); } } @@ -110,55 +110,52 @@ public class ConstantTensorJsonValidator { String dimensionName = parser.getCurrentName(); TensorType.Dimension dimension = tensorDimensions.get(dimensionName); if (dimension == null) { - throw new InvalidConstantTensor(parser, String.format("Tensor dimension \"%s\" does not exist", parser.getCurrentName())); + throw new InvalidConstantTensor(parser, String.format("Tensor dimension '%s' does not exist", parser.getCurrentName())); } if (!cellDimensions.contains(dimensionName)) { - throw new InvalidConstantTensor(parser, String.format("Duplicate tensor dimension \"%s\"", parser.getCurrentName())); + throw new InvalidConstantTensor(parser, String.format("Duplicate tensor dimension '%s'", parser.getCurrentName())); } cellDimensions.remove(dimensionName); - validateTensorCoordinate(dimension); + validateLabel(dimension); } if (!cellDimensions.isEmpty()) { - throw new InvalidConstantTensor(parser, String.format("Tensor address missing dimension(s): %s", Joiner.on(", ").join(cellDimensions))); + throw new InvalidConstantTensor(parser, String.format("Tensor address missing dimension(s) %s", Joiner.on(", ").join(cellDimensions))); } } - /* - * Tensor coordinates are always strings. Coordinates for a mapped dimension can be any string, + /** + * Tensor labels are always strings. Labels for a mapped dimension can be any string, * but those for indexed dimensions needs to be able to be interpreted as integers, and, * additionally, those for indexed bounded dimensions needs to fall within the dimension size. */ - private void validateTensorCoordinate(TensorType.Dimension dimension) throws IOException { + private void validateLabel(TensorType.Dimension dimension) throws IOException { JsonToken token = parser.nextToken(); - if (token != JsonToken.VALUE_STRING) { - throw new InvalidConstantTensor(parser, String.format("Tensor coordinate is not a string (%s)", token.toString())); - } + if (token != JsonToken.VALUE_STRING) + throw new InvalidConstantTensor(parser, String.format("Tensor label is not a string (%s)", token.toString())); if (dimension instanceof TensorType.IndexedBoundDimension) { - validateBoundedCoordinate((TensorType.IndexedBoundDimension) dimension); + validateBoundIndex((TensorType.IndexedBoundDimension) dimension); } else if (dimension instanceof TensorType.IndexedUnboundDimension) { - validateUnboundedCoordinate(dimension); + validateUnboundIndex(dimension); } } - private void validateBoundedCoordinate(TensorType.IndexedBoundDimension dimension) { + private void validateBoundIndex(TensorType.IndexedBoundDimension dimension) { wrapIOException(() -> { try { int value = Integer.parseInt(parser.getValueAsString()); - if (value >= dimension.size().get()) { - throw new InvalidConstantTensor(parser, String.format("Coordinate \"%s\" not within limits of bounded dimension %s", value, dimension.name())); - - } + if (value >= dimension.size().get()) + throw new InvalidConstantTensor(parser, String.format("Index %s not within limits of bound dimension '%s'", value, dimension.name())); } catch (NumberFormatException e) { throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name()); } }); } - private void validateUnboundedCoordinate(TensorType.Dimension dimension) { + private void validateUnboundIndex(TensorType.Dimension dimension) { wrapIOException(() -> { try { Integer.parseInt(parser.getValueAsString()); @@ -169,7 +166,7 @@ public class ConstantTensorJsonValidator { } private void throwCoordinateIsNotInteger(String value, String dimensionName) { - throw new InvalidConstantTensor(parser, String.format("Coordinate \"%s\" for dimension %s is not an integer", value, dimensionName)); + throw new InvalidConstantTensor(parser, String.format("Index '%s' for dimension '%s' is not an integer", value, dimensionName)); } private void validateTensorValue() throws IOException { @@ -198,11 +195,12 @@ public class ConstantTensorJsonValidator { String actualFieldName = parser.getCurrentName(); if (!actualFieldName.equals(wantedFieldName)) { - throw new InvalidConstantTensor(parser, String.format("Expected field name \"%s\", got \"%s\"", wantedFieldName, actualFieldName)); + throw new InvalidConstantTensor(parser, String.format("Expected field name '%s', got '%s'", wantedFieldName, actualFieldName)); } } static class InvalidConstantTensor extends RuntimeException { + InvalidConstantTensor(JsonParser parser, String message) { super(message + " " + parser.getCurrentLocation().toString()); } @@ -210,6 +208,7 @@ public class ConstantTensorJsonValidator { InvalidConstantTensor(JsonParser parser, Exception base) { super("Failed to parse JSON stream " + parser.getCurrentLocation().toString(), base); } + } @FunctionalInterface diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java index 9568ea5c27c..cf8a9201668 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java @@ -31,7 +31,7 @@ public class RankingConstantsValidator extends Validator { public ExceptionMessageCollector add(Throwable throwable, String rcName, String rcFilename) { exceptionsOccurred = true; - combinedMessage += String.format("\nRanking constant \"%s\" (%s): %s", rcName, rcFilename, throwable.getMessage()); + combinedMessage += String.format("\nRanking constant '%s' (%s): %s", rcName, rcFilename, throwable.getMessage()); return this; } } diff --git a/config-model/src/test/cfg/application/validation/ranking_constants_fail/searchdefinitions/simple.sd b/config-model/src/test/cfg/application/validation/ranking_constants_fail/searchdefinitions/simple.sd index 126f8d00724..8b782a01946 100644 --- a/config-model/src/test/cfg/application/validation/ranking_constants_fail/searchdefinitions/simple.sd +++ b/config-model/src/test/cfg/application/validation/ranking_constants_fail/searchdefinitions/simple.sd @@ -4,12 +4,12 @@ search simple { constant constant_tensor_1 { file: tensors/constant_tensor_1.json - type: tensor(x[], y[]) + type: tensor(x[4], y[3]) } constant constant_tensor_2 { file: tensors/constant_tensor_2.json - type: tensor(x[]) + type: tensor(x[6]) } constant constant_tensor_3 { @@ -24,6 +24,6 @@ search simple { constant constant_tensor_5 { file: tensors/constant_tensor_5.json - type: tensor(x[], y[], z[]) + type: tensor(x[33], y[10], z[46]) } } diff --git a/config-model/src/test/cfg/application/validation/ranking_constants_ok/searchdefinitions/simple.sd b/config-model/src/test/cfg/application/validation/ranking_constants_ok/searchdefinitions/simple.sd index 126f8d00724..8b782a01946 100644 --- a/config-model/src/test/cfg/application/validation/ranking_constants_ok/searchdefinitions/simple.sd +++ b/config-model/src/test/cfg/application/validation/ranking_constants_ok/searchdefinitions/simple.sd @@ -4,12 +4,12 @@ search simple { constant constant_tensor_1 { file: tensors/constant_tensor_1.json - type: tensor(x[], y[]) + type: tensor(x[4], y[3]) } constant constant_tensor_2 { file: tensors/constant_tensor_2.json - type: tensor(x[]) + type: tensor(x[6]) } constant constant_tensor_3 { @@ -24,6 +24,6 @@ search simple { constant constant_tensor_5 { file: tensors/constant_tensor_5.json - type: tensor(x[], y[], z[]) + type: tensor(x[33], y[10], z[46]) } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 35679ffa762..02d1c3fc3b0 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -184,19 +184,19 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { " }\n" + " }\n" + " constant W_hidden {\n" + - " type: tensor(x[])\n" + + " type: tensor(x[1])\n" + " file: ignored.json\n" + " }\n" + " constant b_input {\n" + - " type: tensor(x[])\n" + + " type: tensor(x[1])\n" + " file: ignored.json\n" + " }\n" + " constant W_final {\n" + - " type: tensor(x[])\n" + + " type: tensor(x[1])\n" + " file: ignored.json\n" + " }\n" + " constant b_final {\n" + - " type: tensor(x[])\n" + + " type: tensor(x[1])\n" + " file: ignored.json\n" + " }\n" + "}\n"); @@ -211,11 +211,11 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { censorBindingHash(testRankProperties.get(0).toString())); assertEquals("(rankingExpression(hidden_layer).rankingScript,rankingExpression(relu@))", censorBindingHash(testRankProperties.get(1).toString())); - assertEquals("(rankingExpression(hidden_layer).type,tensor(x[]))", + assertEquals("(rankingExpression(hidden_layer).type,tensor(x[1]))", censorBindingHash(testRankProperties.get(2).toString())); assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))", testRankProperties.get(3).toString()); - assertEquals("(rankingExpression(final_layer).type,tensor(x[]))", + assertEquals("(rankingExpression(final_layer).type,tensor(x[1]))", testRankProperties.get(4).toString()); assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(5).toString()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java index 1f8dcc2da64..b594a3329d1 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java @@ -86,9 +86,9 @@ public class ConstantTensorJsonValidatorTest { } @Test - public void ensure_that_bounded_tensor_outside_limits_is_disallowed() { + public void ensure_that_bound_tensor_outside_limits_is_disallowed() { expectedException.expect(InvalidConstantTensor.class); - expectedException.expectMessage("Coordinate \"5\" not within limits of bounded dimension x"); + expectedException.expectMessage("Index 5 not within limits of bound dimension 'x'"); validateTensorJson( TensorType.fromSpec("tensor(x[5], y[10])"), @@ -119,9 +119,9 @@ public class ConstantTensorJsonValidatorTest { } @Test - public void ensure_that_non_integer_strings_in_address_points_are_disallowed_unbounded() { + public void ensure_that_non_integer_strings_in_address_points_are_disallowed_unbound() { expectedException.expect(InvalidConstantTensor.class); - expectedException.expectMessage("Coordinate \"a\" for dimension x is not an integer"); + expectedException.expectMessage("Index 'a' for dimension 'x' is not an integer"); validateTensorJson( TensorType.fromSpec("tensor(x[])"), @@ -139,7 +139,7 @@ public class ConstantTensorJsonValidatorTest { @Test public void ensure_that_tensor_coordinates_are_strings() { expectedException.expect(InvalidConstantTensor.class); - expectedException.expectMessage("Tensor coordinate is not a string (VALUE_NUMBER_INT)"); + expectedException.expectMessage("Tensor label is not a string (VALUE_NUMBER_INT)"); validateTensorJson( TensorType.fromSpec("tensor(x[])"), @@ -157,7 +157,7 @@ public class ConstantTensorJsonValidatorTest { @Test public void ensure_that_non_integer_strings_in_address_points_are_disallowed_bounded() { expectedException.expect(InvalidConstantTensor.class); - expectedException.expectMessage("Coordinate \"a\" for dimension x is not an integer"); + expectedException.expectMessage("Index 'a' for dimension 'x' is not an integer"); validateTensorJson( TensorType.fromSpec("tensor(x[5])"), @@ -175,7 +175,7 @@ public class ConstantTensorJsonValidatorTest { @Test public void ensure_that_missing_coordinates_fail() { expectedException.expect(InvalidConstantTensor.class); - expectedException.expectMessage("Tensor address missing dimension(s): y, z"); + expectedException.expectMessage("Tensor address missing dimension(s) y, z"); validateTensorJson( TensorType.fromSpec("tensor(x[], y[], z[])"), @@ -211,7 +211,7 @@ public class ConstantTensorJsonValidatorTest { @Test public void ensure_that_extra_dimensions_are_disallowed() { expectedException.expect(InvalidConstantTensor.class); - expectedException.expectMessage("Tensor dimension \"z\" does not exist"); + expectedException.expectMessage("Tensor dimension 'z' does not exist"); validateTensorJson( TensorType.fromSpec("tensor(x[], y[])"), @@ -229,7 +229,7 @@ public class ConstantTensorJsonValidatorTest { @Test public void ensure_that_duplicate_dimensions_are_disallowed() { expectedException.expect(InvalidConstantTensor.class); - expectedException.expectMessage("Duplicate tensor dimension \"y\""); + expectedException.expectMessage("Duplicate tensor dimension 'y'"); validateTensorJson( TensorType.fromSpec("tensor(x[], y[])"), @@ -265,7 +265,7 @@ public class ConstantTensorJsonValidatorTest { @Test public void ensure_that_invalid_json_not_in_tensor_format_fails() { expectedException.expect(InvalidConstantTensor.class); - expectedException.expectMessage("Expected field name \"cells\", got \"stats\""); + expectedException.expectMessage("Expected field name 'cells', got 'stats'"); validateTensorJson(TensorType.fromSpec("tensor(x[], y[])"), inputJsonToReader( diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidatorTest.java index d99fd93d5eb..dbf6013b167 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidatorTest.java @@ -9,6 +9,7 @@ import org.junit.rules.ExpectedException; import static com.yahoo.vespa.model.application.validation.RankingConstantsValidator.TensorValidationFailed; public class RankingConstantsValidatorTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); @@ -20,10 +21,11 @@ public class RankingConstantsValidatorTest { @Test public void ensure_that_failing_ranking_constants_fails() { expectedException.expect(TensorValidationFailed.class); - expectedException.expectMessage("Ranking constant \"constant_tensor_2\" (tensors/constant_tensor_2.json): Tensor coordinate is not a string (VALUE_NUMBER_INT)"); - expectedException.expectMessage("Ranking constant \"constant_tensor_3\" (tensors/constant_tensor_3.json): Tensor dimension \"cd\" does not exist"); - expectedException.expectMessage("Ranking constant \"constant_tensor_4\" (tensors/constant_tensor_4.json): Tensor dimension \"z\" does not exist"); + expectedException.expectMessage("Ranking constant 'constant_tensor_2' (tensors/constant_tensor_2.json): Tensor label is not a string (VALUE_NUMBER_INT)"); + expectedException.expectMessage("Ranking constant 'constant_tensor_3' (tensors/constant_tensor_3.json): Tensor dimension 'cd' does not exist"); + expectedException.expectMessage("Ranking constant 'constant_tensor_4' (tensors/constant_tensor_4.json): Tensor dimension 'z' does not exist"); new VespaModelCreatorWithFilePkg("src/test/cfg/application/validation/ranking_constants_fail/").create(); } + } |