summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java103
1 files changed, 99 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index fa2094e9d2a..9eb9cb06666 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -103,10 +103,17 @@ public class JsonFormat {
if ( ! (builder instanceof IndexedTensor.BoundBuilder))
throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
"Use 'cells' or 'blocks' instead");
+ IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
+ if (values.type() == Type.STRING) {
+ double[] decoded = decodeHexString(values.asString(), builder.type().valueType());
+ for (int i = 0; i < decoded.length; i++) {
+ indexedBuilder.cellByDirectIndex(i, decoded[i]);
+ }
+ return;
+ }
if ( values.type() != Type.ARRAY)
throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type());
- IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
MutableInteger index = new MutableInteger(0);
values.traverse((ArrayTraverser) (__, value) -> {
if (value.type() != Type.LONG && value.type() != Type.DOUBLE)
@@ -143,11 +150,99 @@ public class JsonFormat {
decodeValues(value, mixedBuilder));
}
+ private static byte decodeHex(String input, int index) {
+ int d = Character.digit(input.charAt(index), 16);
+ if (d < 0) {
+ throw new IllegalArgumentException("Invalid digit '"+input.charAt(index)+"' at index "+index+" in input "+input);
+ }
+ return (byte)d;
+ }
+
+ private static double[] decodeHexStringAsBytes(String input) {
+ int l = input.length() / 2;
+ double[] result = new double[l];
+ int idx = 0;
+ for (int i = 0; i < l; i++) {
+ byte v = decodeHex(input, idx++);
+ v <<= 4;
+ v += decodeHex(input, idx++);
+ result[i] = v;
+ }
+ return result;
+ }
+
+ private static double[] decodeHexStringAsBFloat16s(String input) {
+ int l = input.length() / 4;
+ double[] result = new double[l];
+ int idx = 0;
+ for (int i = 0; i < l; i++) {
+ int v = decodeHex(input, idx++);
+ v <<= 4; v += decodeHex(input, idx++);
+ v <<= 4; v += decodeHex(input, idx++);
+ v <<= 4; v += decodeHex(input, idx++);
+ v <<= 16;
+ result[i] = Float.intBitsToFloat(v);
+ }
+ return result;
+ }
+
+ private static double[] decodeHexStringAsFloats(String input) {
+ int l = input.length() / 8;
+ double[] result = new double[l];
+ int idx = 0;
+ for (int i = 0; i < l; i++) {
+ int v = 0;
+ for (int j = 0; j < 8; j++) {
+ v <<= 4;
+ v += decodeHex(input, idx++);
+ }
+ result[i] = Float.intBitsToFloat(v);
+ }
+ return result;
+ }
+
+ private static double[] decodeHexStringAsDoubles(String input) {
+ int l = input.length() / 16;
+ double[] result = new double[l];
+ int idx = 0;
+ for (int i = 0; i < l; i++) {
+ long v = 0;
+ for (int j = 0; j < 16; j++) {
+ v <<= 4;
+ v += decodeHex(input, idx++);
+ }
+ result[i] = Double.longBitsToDouble(v);
+ }
+ return result;
+ }
+
+ private static double[] decodeHexString(String input, TensorType.Value valueType) {
+ switch(valueType) {
+ case INT8:
+ return decodeHexStringAsBytes(input);
+ case BFLOAT16:
+ return decodeHexStringAsBFloat16s(input);
+ case FLOAT:
+ return decodeHexStringAsFloats(input);
+ case DOUBLE:
+ return decodeHexStringAsDoubles(input);
+ default:
+ throw new IllegalArgumentException("Cannot handle value type: "+valueType);
+ }
+ }
+
private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
- if (valuesField.type() != Type.ARRAY)
- throw new IllegalArgumentException("Expected a block to contain a 'values' array");
double[] values = new double[(int)mixedBuilder.denseSubspaceSize()];
- valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value));
+ if (valuesField.type() == Type.ARRAY) {
+ valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value));
+ } else if (valuesField.type() == Type.STRING) {
+ double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType());
+ for (int i = 0; i < decoded.length; i++) {
+ values[i] = decoded[i];
+ }
+ } else {
+ throw new IllegalArgumentException("Expected a block to contain a 'values' array");
+ }
return values;
}