diff options
author | Lester Solbakken <lesters@oath.com> | 2021-09-02 15:59:15 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-09-02 15:59:15 +0200 |
commit | 691ddac6701fb0d45b0ef2bbd49e5ab99bcc6c17 (patch) | |
tree | c2982328f7351182e4d7451f4c326563f0b24ca7 /vespajlib | |
parent | 045321f2a1b2d00d33ccf7c64c2708b6e2c94667 (diff) |
Disallow feeding empty indexed tensors
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 23 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 24 |
2 files changed, 41 insertions, 6 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 80b37e43c3d..cb7539d8565 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -94,17 +94,17 @@ public class JsonFormat { else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'"); + throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'"); return builder.build(); } private static void decodeCells(Inspector cells, Tensor.Builder builder) { - if ( cells.type() == Type.ARRAY) + if (cells.type() == Type.ARRAY) cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder)); else if (cells.type() == Type.OBJECT) cells.traverse((ObjectTraverser) (key, value) -> decodeSingleDimensionCell(key, value, builder)); else - throw new IllegalArgumentException("Excepted 'cells' to contain an array or obejct, not " + cells.type()); + throw new IllegalArgumentException("Excepted 'cells' to contain an array or object, not " + cells.type()); } private static void decodeCell(Inspector cell, Tensor.Builder builder) { @@ -128,18 +128,23 @@ public class JsonFormat { IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; if (values.type() == Type.STRING) { double[] decoded = decodeHexString(values.asString(), builder.type().valueType()); + if (decoded.length == 0) + throw new IllegalArgumentException("The 'values' string does not contain any values"); for (int i = 0; i < decoded.length; i++) { indexedBuilder.cellByDirectIndex(i, decoded[i]); } return; } - if ( values.type() != Type.ARRAY) + if (values.type() != Type.ARRAY) throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type()); + if (values.entries() == 0) + throw new IllegalArgumentException("The 'values' array does not contain any values"); MutableInteger index = new MutableInteger(0); values.traverse((ArrayTraverser) (__, value) -> { - if (value.type() != Type.LONG && value.type() != Type.DOUBLE) + if (value.type() != Type.LONG && value.type() != Type.DOUBLE) { throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type()); + } indexedBuilder.cellByDirectIndex(index.next(), value.asDouble()); }); } @@ -167,7 +172,7 @@ public class JsonFormat { private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) { if (value.type() != Type.ARRAY) - throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an object, not " + value.type()); + throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type()); mixedBuilder.block(asAddress(key, mixedBuilder.type().mappedSubtype()), decodeValues(value, mixedBuilder)); } @@ -256,9 +261,15 @@ public class JsonFormat { private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; if (valuesField.type() == Type.ARRAY) { + if (valuesField.entries() == 0) { + throw new IllegalArgumentException("The 'block' value array does not contain any values"); + } valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); } else if (valuesField.type() == Type.STRING) { double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType()); + if (decoded.length == 0) { + throw new IllegalArgumentException("The 'block' value string does not contain any values"); + } for (int i = 0; i < decoded.length; i++) { values[i] = decoded[i]; } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 2f1e3be9299..87796501917 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -68,6 +68,21 @@ public class JsonFormatTestCase { } @Test + public void testDisallowedEmptyDenseTensor() { + TensorType type = TensorType.fromSpec("tensor(x[3])"); + assertDecodeFails(type, "{\"values\":[]}", "The 'values' array does not contain any values"); + assertDecodeFails(type, "{\"values\":\"\"}", "The 'values' string does not contain any values"); + } + + @Test + public void testDisallowedEmptyMixedTensor() { + TensorType type = TensorType.fromSpec("tensor(x{},y[3])"); + assertDecodeFails(type, "{\"blocks\":{ \"a\": [] } }", "The 'block' value array does not contain any values"); + assertDecodeFails(type, "{\"blocks\":[ {\"address\":{\"x\":\"a\"}, \"values\": [] } ] }", + "The 'block' value array does not contain any values"); + } + + @Test public void testDenseTensorInDenseForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])")); builder.cell().label("x", 0).label("y", 0).value(2.0); @@ -304,4 +319,13 @@ public class JsonFormatTestCase { assertEquals(expected, new String(json, StandardCharsets.UTF_8)); } + private void assertDecodeFails(TensorType type, String format, String msg) { + try { + Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8)); + fail("Did not get exception as expected, decoded as: " + decoded); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), msg); + } + } + } |