summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java67
1 files changed, 38 insertions, 29 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index a7afc1efc6d..0e8fbc30bb6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -60,25 +60,21 @@ public class JsonFormat {
Cursor root = slime.setObject();
root.setString("type", tensor.type().toString());
- // Encode as nested lists if indexed tensor
- if (tensor instanceof IndexedTensor) {
- IndexedTensor denseTensor = (IndexedTensor) tensor;
+ if (tensor instanceof IndexedTensor denseTensor) {
+ // Encode as nested lists if indexed tensor
encodeValues(denseTensor, root.setArray("values"), new long[denseTensor.dimensionSizes().dimensions()], 0);
}
-
- // Short form for a single mapped dimension
else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) {
+ // Short form for a single mapped dimension
encodeSingleDimensionCells((MappedTensor) tensor, root);
}
-
- // Short form for a mixed tensor
else if (tensor instanceof MixedTensor &&
tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() >= 1) {
+ // Short form for a mixed tensor
encodeBlocks((MixedTensor) tensor, root);
}
-
- // No other short forms exist: default to standard cell address output
else {
+ // No other short forms exist: default to standard cell address output
encodeCells(tensor, root);
}
@@ -177,17 +173,25 @@ public class JsonFormat {
Tensor.Builder builder = Tensor.Builder.of(type);
Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
- if (root.field("cells").valid())
+ if (root.field("cells").valid() && ! primitiveContent(root.field("cells")))
decodeCells(root.field("cells"), builder);
- else if (root.field("values").valid())
+ else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed()))
decodeValues(root.field("values"), builder);
else if (root.field("blocks").valid())
decodeBlocks(root.field("blocks"), builder);
- else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
- throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'");
+ else
+ decodeDirectValue(root, builder);
return builder.build();
}
+ private static boolean primitiveContent(Inspector cellsValue) {
+ if (cellsValue.type() == Type.DOUBLE) return true;
+ if (cellsValue.type() == Type.LONG) return true;
+ if (cellsValue.type() == Type.ARRAY && cellsValue.entries() > 0 &&
+ ( cellsValue.entry(0).type() == Type.DOUBLE || cellsValue.entry(0).type() == Type.LONG)) return true;
+ return false;
+ }
+
private static void decodeCells(Inspector cells, Tensor.Builder builder) {
if (cells.type() == Type.ARRAY)
cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder));
@@ -212,10 +216,9 @@ public class JsonFormat {
}
private static void decodeValues(Inspector values, Tensor.Builder builder) {
- if ( ! (builder instanceof IndexedTensor.BoundBuilder))
+ if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder))
throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
"Use 'cells' or 'blocks' instead");
- IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
if (values.type() == Type.STRING) {
double[] decoded = decodeHexString(values.asString(), builder.type().valueType());
if (decoded.length == 0)
@@ -240,10 +243,9 @@ public class JsonFormat {
}
private static void decodeBlocks(Inspector values, Tensor.Builder builder) {
- if ( ! (builder instanceof MixedTensor.BoundBuilder))
+ if ( ! (builder instanceof MixedTensor.BoundBuilder mixedBuilder))
throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " +
"Use 'cells' or 'values' instead");
- MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder;
if (values.type() == Type.ARRAY)
values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder));
@@ -260,6 +262,19 @@ public class JsonFormat {
decodeValues(block.field("values"), mixedBuilder));
}
+ /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */
+ private static void decodeDirectValue(Inspector root, Tensor.Builder builder) {
+ boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
+ boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
+
+ if ( ! hasMapped)
+ decodeValues(root, builder);
+ else if (hasMapped && hasIndexed)
+ decodeBlocks(root, builder);
+ else
+ decodeCells(root, builder);
+ }
+
private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) {
if (value.type() != Type.ARRAY)
throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type());
@@ -334,18 +349,12 @@ public class JsonFormat {
}
public static double[] decodeHexString(String input, TensorType.Value valueType) {
- switch(valueType) {
- case INT8:
- return decodeHexStringAsBytes(input);
- case BFLOAT16:
- return decodeHexStringAsBFloat16s(input);
- case FLOAT:
- return decodeHexStringAsFloats(input);
- case DOUBLE:
- return decodeHexStringAsDoubles(input);
- default:
- throw new IllegalArgumentException("Cannot handle value type: "+valueType);
- }
+ return switch (valueType) {
+ case INT8 -> decodeHexStringAsBytes(input);
+ case BFLOAT16 -> decodeHexStringAsBFloat16s(input);
+ case FLOAT -> decodeHexStringAsFloats(input);
+ case DOUBLE -> decodeHexStringAsDoubles(input);
+ };
}
private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {