diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 103 |
1 files changed, 99 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index fa2094e9d2a..9eb9cb06666 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -103,10 +103,17 @@ public class JsonFormat { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); + IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; + if (values.type() == Type.STRING) { + double[] decoded = decodeHexString(values.asString(), builder.type().valueType()); + for (int i = 0; i < decoded.length; i++) { + indexedBuilder.cellByDirectIndex(i, decoded[i]); + } + return; + } if ( values.type() != Type.ARRAY) throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type()); - IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; MutableInteger index = new MutableInteger(0); values.traverse((ArrayTraverser) (__, value) -> { if (value.type() != Type.LONG && value.type() != Type.DOUBLE) @@ -143,11 +150,99 @@ public class JsonFormat { decodeValues(value, mixedBuilder)); } + private static byte decodeHex(String input, int index) { + int d = Character.digit(input.charAt(index), 16); + if (d < 0) { + throw new IllegalArgumentException("Invalid digit '"+input.charAt(index)+"' at index "+index+" in input "+input); + } + return (byte)d; + } + + private static double[] decodeHexStringAsBytes(String input) { + int l = input.length() / 2; + double[] result = new double[l]; + int idx = 0; + for (int i = 0; i < l; i++) { + byte v = decodeHex(input, idx++); + v <<= 4; + v += decodeHex(input, idx++); + result[i] = v; + } + return result; + } + + private static double[] decodeHexStringAsBFloat16s(String input) { + int l = input.length() / 4; + double[] result = new double[l]; + int idx = 0; + for (int i = 0; i < l; i++) { + int v = decodeHex(input, idx++); + v <<= 4; v += decodeHex(input, idx++); + v <<= 4; v += decodeHex(input, idx++); + v <<= 4; v += decodeHex(input, idx++); + v <<= 16; + result[i] = Float.intBitsToFloat(v); + } + return result; + } + + private static double[] decodeHexStringAsFloats(String input) { + int l = input.length() / 8; + double[] result = new double[l]; + int idx = 0; + for (int i = 0; i < l; i++) { + int v = 0; + for (int j = 0; j < 8; j++) { + v <<= 4; + v += decodeHex(input, idx++); + } + result[i] = Float.intBitsToFloat(v); + } + return result; + } + + private static double[] decodeHexStringAsDoubles(String input) { + int l = input.length() / 16; + double[] result = new double[l]; + int idx = 0; + for (int i = 0; i < l; i++) { + long v = 0; + for (int j = 0; j < 16; j++) { + v <<= 4; + v += decodeHex(input, idx++); + } + result[i] = Double.longBitsToDouble(v); + } + return result; + } + + private static double[] decodeHexString(String input, TensorType.Value valueType) { + switch(valueType) { + case INT8: + return decodeHexStringAsBytes(input); + case BFLOAT16: + return decodeHexStringAsBFloat16s(input); + case FLOAT: + return decodeHexStringAsFloats(input); + case DOUBLE: + return decodeHexStringAsDoubles(input); + default: + throw new IllegalArgumentException("Cannot handle value type: "+valueType); + } + } + private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { - if (valuesField.type() != Type.ARRAY) - throw new IllegalArgumentException("Expected a block to contain a 'values' array"); double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; - valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); + if (valuesField.type() == Type.ARRAY) { + valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); + } else if (valuesField.type() == Type.STRING) { + double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType()); + for (int i = 0; i < decoded.length; i++) { + values[i] = decoded[i]; + } + } else { + throw new IllegalArgumentException("Expected a block to contain a 'values' array"); + } return values; } |