diff options
4 files changed, 70 insertions, 6 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 27426f584bd..a005a15042f 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -94,6 +94,8 @@ public class TensorReader { IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; if (buffer.currentToken() == JsonToken.VALUE_STRING) { double[] decoded = decodeHexString(buffer.currentText(), builder.type().valueType()); + if (decoded.length == 0) + throw new IllegalArgumentException("The 'values' string does not contain any values"); for (int i = 0; i < decoded.length; i++) { indexedBuilder.cellByDirectIndex(i, decoded[i]); } @@ -104,6 +106,8 @@ public class TensorReader { for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { indexedBuilder.cellByDirectIndex(index++, readDouble(buffer)); } + if (index == 0) + throw new IllegalArgumentException("The 'values' array does not contain any values"); expectCompositeEnd(buffer.currentToken()); } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index e50fd9734f7..7f4b420cf9a 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1282,6 +1282,22 @@ public class JsonReaderTestCase { } @Test + public void testDisallowedDenseTensorShortFormWithoutValues() { + assertCreatePutFails(inputJson("{ 'values': [] }"), "dense_tensor", + "The 'values' array does not contain any values"); + assertCreatePutFails(inputJson("{ 'values': '' }"), "dense_tensor", + "The 'values' string does not contain any values"); + } + + @Test + public void testDisallowedMixedTensorShortFormWithoutValues() { + assertCreatePutFails(inputJson("{\"blocks\":{ \"a\": [] } }"), + "mixed_tensor", "Expected 3 values, but got 0"); + assertCreatePutFails(inputJson("{\"blocks\":[ {\"address\":{\"x\":\"a\"}, \"values\": [] } ] }"), + "mixed_tensor", "Expected 3 values, but got 0"); + } + + @Test public void testParsingOfSparseTensorWithCells() { Tensor tensor = assertSparseTensorField("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}", createPutWithSparseTensor(inputJson("{", @@ -2029,4 +2045,13 @@ public class JsonReaderTestCase { new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); } + private void assertCreatePutFails(String tensor, String name, String msg) { + try { + createPutWithTensor(inputJson(tensor), name); + fail("Expected exception"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains(msg)); + } + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 80b37e43c3d..cb7539d8565 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -94,17 +94,17 @@ public class JsonFormat { else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'"); + throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'"); return builder.build(); } private static void decodeCells(Inspector cells, Tensor.Builder builder) { - if ( cells.type() == Type.ARRAY) + if (cells.type() == Type.ARRAY) cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder)); else if (cells.type() == Type.OBJECT) cells.traverse((ObjectTraverser) (key, value) -> decodeSingleDimensionCell(key, value, builder)); else - throw new IllegalArgumentException("Excepted 'cells' to contain an array or obejct, not " + cells.type()); + throw new IllegalArgumentException("Excepted 'cells' to contain an array or object, not " + cells.type()); } private static void decodeCell(Inspector cell, Tensor.Builder builder) { @@ -128,18 +128,23 @@ public class JsonFormat { IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; if (values.type() == Type.STRING) { double[] decoded = decodeHexString(values.asString(), builder.type().valueType()); + if (decoded.length == 0) + throw new IllegalArgumentException("The 'values' string does not contain any values"); for (int i = 0; i < decoded.length; i++) { indexedBuilder.cellByDirectIndex(i, decoded[i]); } return; } - if ( values.type() != Type.ARRAY) + if (values.type() != Type.ARRAY) throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type()); + if (values.entries() == 0) + throw new IllegalArgumentException("The 'values' array does not contain any values"); MutableInteger index = new MutableInteger(0); values.traverse((ArrayTraverser) (__, value) -> { - if (value.type() != Type.LONG && value.type() != Type.DOUBLE) + if (value.type() != Type.LONG && value.type() != Type.DOUBLE) { throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type()); + } indexedBuilder.cellByDirectIndex(index.next(), value.asDouble()); }); } @@ -167,7 +172,7 @@ public class JsonFormat { private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) { if (value.type() != Type.ARRAY) - throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an object, not " + value.type()); + throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type()); mixedBuilder.block(asAddress(key, mixedBuilder.type().mappedSubtype()), decodeValues(value, mixedBuilder)); } @@ -256,9 +261,15 @@ public class JsonFormat { private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; if (valuesField.type() == Type.ARRAY) { + if (valuesField.entries() == 0) { + throw new IllegalArgumentException("The 'block' value array does not contain any values"); + } valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); } else if (valuesField.type() == Type.STRING) { double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType()); + if (decoded.length == 0) { + throw new IllegalArgumentException("The 'block' value string does not contain any values"); + } for (int i = 0; i < decoded.length; i++) { values[i] = decoded[i]; } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 2f1e3be9299..87796501917 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -68,6 +68,21 @@ public class JsonFormatTestCase { } @Test + public void testDisallowedEmptyDenseTensor() { + TensorType type = TensorType.fromSpec("tensor(x[3])"); + assertDecodeFails(type, "{\"values\":[]}", "The 'values' array does not contain any values"); + assertDecodeFails(type, "{\"values\":\"\"}", "The 'values' string does not contain any values"); + } + + @Test + public void testDisallowedEmptyMixedTensor() { + TensorType type = TensorType.fromSpec("tensor(x{},y[3])"); + assertDecodeFails(type, "{\"blocks\":{ \"a\": [] } }", "The 'block' value array does not contain any values"); + assertDecodeFails(type, "{\"blocks\":[ {\"address\":{\"x\":\"a\"}, \"values\": [] } ] }", + "The 'block' value array does not contain any values"); + } + + @Test public void testDenseTensorInDenseForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])")); builder.cell().label("x", 0).label("y", 0).value(2.0); @@ -304,4 +319,13 @@ public class JsonFormatTestCase { assertEquals(expected, new String(json, StandardCharsets.UTF_8)); } + private void assertDecodeFails(TensorType type, String format, String msg) { + try { + Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8)); + fail("Did not get exception as expected, decoded as: " + decoded); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), msg); + } + } + } |