summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-09-02 15:59:15 +0200
committerLester Solbakken <lesters@oath.com>2021-09-02 15:59:15 +0200
commit691ddac6701fb0d45b0ef2bbd49e5ab99bcc6c17 (patch)
treec2982328f7351182e4d7451f4c326563f0b24ca7 /vespajlib/src/main/java/com/yahoo/tensor
parent045321f2a1b2d00d33ccf7c64c2708b6e2c94667 (diff)
Disallow feeding empty indexed tensors
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java23
1 files changed, 17 insertions, 6 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 80b37e43c3d..cb7539d8565 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -94,17 +94,17 @@ public class JsonFormat {
else if (root.field("blocks").valid())
decodeBlocks(root.field("blocks"), builder);
else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
- throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'");
+ throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'");
return builder.build();
}
private static void decodeCells(Inspector cells, Tensor.Builder builder) {
- if ( cells.type() == Type.ARRAY)
+ if (cells.type() == Type.ARRAY)
cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder));
else if (cells.type() == Type.OBJECT)
cells.traverse((ObjectTraverser) (key, value) -> decodeSingleDimensionCell(key, value, builder));
else
- throw new IllegalArgumentException("Excepted 'cells' to contain an array or obejct, not " + cells.type());
+ throw new IllegalArgumentException("Excepted 'cells' to contain an array or object, not " + cells.type());
}
private static void decodeCell(Inspector cell, Tensor.Builder builder) {
@@ -128,18 +128,23 @@ public class JsonFormat {
IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
if (values.type() == Type.STRING) {
double[] decoded = decodeHexString(values.asString(), builder.type().valueType());
+ if (decoded.length == 0)
+ throw new IllegalArgumentException("The 'values' string does not contain any values");
for (int i = 0; i < decoded.length; i++) {
indexedBuilder.cellByDirectIndex(i, decoded[i]);
}
return;
}
- if ( values.type() != Type.ARRAY)
+ if (values.type() != Type.ARRAY)
throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type());
+ if (values.entries() == 0)
+ throw new IllegalArgumentException("The 'values' array does not contain any values");
MutableInteger index = new MutableInteger(0);
values.traverse((ArrayTraverser) (__, value) -> {
- if (value.type() != Type.LONG && value.type() != Type.DOUBLE)
+ if (value.type() != Type.LONG && value.type() != Type.DOUBLE) {
throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type());
+ }
indexedBuilder.cellByDirectIndex(index.next(), value.asDouble());
});
}
@@ -167,7 +172,7 @@ public class JsonFormat {
private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) {
if (value.type() != Type.ARRAY)
- throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an object, not " + value.type());
+ throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type());
mixedBuilder.block(asAddress(key, mixedBuilder.type().mappedSubtype()),
decodeValues(value, mixedBuilder));
}
@@ -256,9 +261,15 @@ public class JsonFormat {
private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
double[] values = new double[(int)mixedBuilder.denseSubspaceSize()];
if (valuesField.type() == Type.ARRAY) {
+ if (valuesField.entries() == 0) {
+ throw new IllegalArgumentException("The 'block' value array does not contain any values");
+ }
valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value));
} else if (valuesField.type() == Type.STRING) {
double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType());
+ if (decoded.length == 0) {
+ throw new IllegalArgumentException("The 'block' value string does not contain any values");
+ }
for (int i = 0; i < decoded.length; i++) {
values[i] = decoded[i];
}