aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-09-02 15:59:15 +0200
committerLester Solbakken <lesters@oath.com>2021-09-02 15:59:15 +0200
commit691ddac6701fb0d45b0ef2bbd49e5ab99bcc6c17 (patch)
treec2982328f7351182e4d7451f4c326563f0b24ca7 /vespajlib
parent045321f2a1b2d00d33ccf7c64c2708b6e2c94667 (diff)
Disallow feeding empty indexed tensors
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java23
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java24
2 files changed, 41 insertions, 6 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 80b37e43c3d..cb7539d8565 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -94,17 +94,17 @@ public class JsonFormat {
else if (root.field("blocks").valid())
decodeBlocks(root.field("blocks"), builder);
else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
- throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'");
+ throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'");
return builder.build();
}
private static void decodeCells(Inspector cells, Tensor.Builder builder) {
- if ( cells.type() == Type.ARRAY)
+ if (cells.type() == Type.ARRAY)
cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder));
else if (cells.type() == Type.OBJECT)
cells.traverse((ObjectTraverser) (key, value) -> decodeSingleDimensionCell(key, value, builder));
else
- throw new IllegalArgumentException("Excepted 'cells' to contain an array or obejct, not " + cells.type());
+ throw new IllegalArgumentException("Excepted 'cells' to contain an array or object, not " + cells.type());
}
private static void decodeCell(Inspector cell, Tensor.Builder builder) {
@@ -128,18 +128,23 @@ public class JsonFormat {
IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
if (values.type() == Type.STRING) {
double[] decoded = decodeHexString(values.asString(), builder.type().valueType());
+ if (decoded.length == 0)
+ throw new IllegalArgumentException("The 'values' string does not contain any values");
for (int i = 0; i < decoded.length; i++) {
indexedBuilder.cellByDirectIndex(i, decoded[i]);
}
return;
}
- if ( values.type() != Type.ARRAY)
+ if (values.type() != Type.ARRAY)
throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type());
+ if (values.entries() == 0)
+ throw new IllegalArgumentException("The 'values' array does not contain any values");
MutableInteger index = new MutableInteger(0);
values.traverse((ArrayTraverser) (__, value) -> {
- if (value.type() != Type.LONG && value.type() != Type.DOUBLE)
+ if (value.type() != Type.LONG && value.type() != Type.DOUBLE) {
throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type());
+ }
indexedBuilder.cellByDirectIndex(index.next(), value.asDouble());
});
}
@@ -167,7 +172,7 @@ public class JsonFormat {
private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) {
if (value.type() != Type.ARRAY)
- throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an object, not " + value.type());
+ throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type());
mixedBuilder.block(asAddress(key, mixedBuilder.type().mappedSubtype()),
decodeValues(value, mixedBuilder));
}
@@ -256,9 +261,15 @@ public class JsonFormat {
private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
double[] values = new double[(int)mixedBuilder.denseSubspaceSize()];
if (valuesField.type() == Type.ARRAY) {
+ if (valuesField.entries() == 0) {
+ throw new IllegalArgumentException("The 'block' value array does not contain any values");
+ }
valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value));
} else if (valuesField.type() == Type.STRING) {
double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType());
+ if (decoded.length == 0) {
+ throw new IllegalArgumentException("The 'block' value string does not contain any values");
+ }
for (int i = 0; i < decoded.length; i++) {
values[i] = decoded[i];
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 2f1e3be9299..87796501917 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -68,6 +68,21 @@ public class JsonFormatTestCase {
}
@Test
+ public void testDisallowedEmptyDenseTensor() {
+ TensorType type = TensorType.fromSpec("tensor(x[3])");
+ assertDecodeFails(type, "{\"values\":[]}", "The 'values' array does not contain any values");
+ assertDecodeFails(type, "{\"values\":\"\"}", "The 'values' string does not contain any values");
+ }
+
+ @Test
+ public void testDisallowedEmptyMixedTensor() {
+ TensorType type = TensorType.fromSpec("tensor(x{},y[3])");
+ assertDecodeFails(type, "{\"blocks\":{ \"a\": [] } }", "The 'block' value array does not contain any values");
+ assertDecodeFails(type, "{\"blocks\":[ {\"address\":{\"x\":\"a\"}, \"values\": [] } ] }",
+ "The 'block' value array does not contain any values");
+ }
+
+ @Test
public void testDenseTensorInDenseForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])"));
builder.cell().label("x", 0).label("y", 0).value(2.0);
@@ -304,4 +319,13 @@ public class JsonFormatTestCase {
assertEquals(expected, new String(json, StandardCharsets.UTF_8));
}
+ private void assertDecodeFails(TensorType type, String format, String msg) {
+ try {
+ Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8));
+ fail("Did not get exception as expected, decoded as: " + decoded);
+ } catch (IllegalArgumentException e) {
+ assertEquals(e.getMessage(), msg);
+ }
+ }
+
}