diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-05-04 08:36:39 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-04 08:36:39 +0200 |
commit | 605dd570c7c21c3df4a1027b9ba9876fdf3005dc (patch) | |
tree | e6f1f790f7a5e8fcf4bf57c16177995796341fb2 /vespajlib/src | |
parent | 448d9205aa15351fbb770a7fbfae88d5b33bbe3d (diff) | |
parent | 66e49b1201f9574bd42b208a56969db87716f2db (diff) |
Merge pull request #17657 from vespa-engine/arnej/add-hex-string-input-format
allow a string (with a hex dump of binary representation) as cell values
Diffstat (limited to 'vespajlib/src')
3 files changed, 226 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 4f0a08ac202..3133752bc49 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -368,6 +368,8 @@ public interface Tensor { } static boolean approxEquals(double x, double y, double tolerance) { + if (x == y) return true; + if (Double.isNaN(x) && Double.isNaN(y)) return true; return Math.abs(x-y) < tolerance; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index fa2094e9d2a..9eb9cb06666 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -103,10 +103,17 @@ public class JsonFormat { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); + IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; + if (values.type() == Type.STRING) { + double[] decoded = decodeHexString(values.asString(), builder.type().valueType()); + for (int i = 0; i < decoded.length; i++) { + indexedBuilder.cellByDirectIndex(i, decoded[i]); + } + return; + } if ( values.type() != Type.ARRAY) throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type()); - IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; MutableInteger index = new MutableInteger(0); values.traverse((ArrayTraverser) (__, value) -> { if (value.type() != Type.LONG && value.type() != Type.DOUBLE) @@ -143,11 +150,99 @@ public class JsonFormat { decodeValues(value, mixedBuilder)); } + private static byte decodeHex(String input, int index) { + int d = Character.digit(input.charAt(index), 16); + if (d < 0) { + throw new IllegalArgumentException("Invalid digit '"+input.charAt(index)+"' at index "+index+" in input "+input); + } + return (byte)d; + } + + private static double[] decodeHexStringAsBytes(String input) { + int l = input.length() / 2; + double[] result = new double[l]; + int idx = 0; + for (int i = 0; i < l; i++) { + byte v = decodeHex(input, idx++); + v <<= 4; + v += decodeHex(input, idx++); + result[i] = v; + } + return result; + } + + private static double[] decodeHexStringAsBFloat16s(String input) { + int l = input.length() / 4; + double[] result = new double[l]; + int idx = 0; + for (int i = 0; i < l; i++) { + int v = decodeHex(input, idx++); + v <<= 4; v += decodeHex(input, idx++); + v <<= 4; v += decodeHex(input, idx++); + v <<= 4; v += decodeHex(input, idx++); + v <<= 16; + result[i] = Float.intBitsToFloat(v); + } + return result; + } + + private static double[] decodeHexStringAsFloats(String input) { + int l = input.length() / 8; + double[] result = new double[l]; + int idx = 0; + for (int i = 0; i < l; i++) { + int v = 0; + for (int j = 0; j < 8; j++) { + v <<= 4; + v += decodeHex(input, idx++); + } + result[i] = Float.intBitsToFloat(v); + } + return result; + } + + private static double[] decodeHexStringAsDoubles(String input) { + int l = input.length() / 16; + double[] result = new double[l]; + int idx = 0; + for (int i = 0; i < l; i++) { + long v = 0; + for (int j = 0; j < 16; j++) { + v <<= 4; + v += decodeHex(input, idx++); + } + result[i] = Double.longBitsToDouble(v); + } + return result; + } + + private static double[] decodeHexString(String input, TensorType.Value valueType) { + switch(valueType) { + case INT8: + return decodeHexStringAsBytes(input); + case BFLOAT16: + return decodeHexStringAsBFloat16s(input); + case FLOAT: + return decodeHexStringAsFloats(input); + case DOUBLE: + return decodeHexStringAsDoubles(input); + default: + throw new IllegalArgumentException("Cannot handle value type: "+valueType); + } + } + private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { - if (valuesField.type() != Type.ARRAY) - throw new IllegalArgumentException("Expected a block to contain a 'values' array"); double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; - valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); + if (valuesField.type() == Type.ARRAY) { + valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); + } else if (valuesField.type() == Type.STRING) { + double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType()); + for (int i = 0; i < decoded.length; i++) { + values[i] = decoded[i]; + } + } else { + throw new IllegalArgumentException("Expected a block to contain a 'values' array"); + } return values; } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 3ca20661587..011c4b1fe12 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -82,6 +82,131 @@ public class JsonFormatTestCase { } @Test + public void testInt8VectorInHexForm() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(127.0); + builder.cell().label("x", 0).label("y", 2).value(-1.0); + builder.cell().label("x", 1).label("y", 0).value(-128.0); + builder.cell().label("x", 1).label("y", 1).value(0.0); + builder.cell().label("x", 1).label("y", 2).value(42.0); + Tensor expected = builder.build(); + String denseJson = "{\"values\":\"027FFF80002A\"}"; + Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, decoded); + } + + @Test + public void testInt8VectorInvalidHex() { + var type = TensorType.fromSpec("tensor<int8>(x[2])"); + String denseJson = "{\"values\":\"abXc\"}"; + try { + Tensor decoded = JsonFormat.decode(type, denseJson.getBytes(StandardCharsets.UTF_8)); + fail("did not get exception as expected, decoded as: "+decoded); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "Invalid digit 'X' at index 2 in input abXc"); + } + } + + @Test + public void testMixedInt8TensorWithHexForm() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x{},y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 0).label("y", 2).value(4.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(6.0); + builder.cell().label("x", 1).label("y", 2).value(7.0); + Tensor expected = builder.build(); + String mixedJson = "{\"blocks\":[" + + "{\"address\":{\"x\":\"0\"},\"values\":\"020304\"}," + + "{\"address\":{\"x\":\"1\"},\"values\":\"050607\"}" + + "]}"; + Tensor decoded = JsonFormat.decode(expected.type(), mixedJson.getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, decoded); + } + + @Test + public void testBFloat16VectorInHexForm() { + var builder = Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[3],y[4])")); + builder.cell().label("x", 0).label("y", 0).value(42.0); + builder.cell().label("x", 0).label("y", 1).value(1048576.0); + builder.cell().label("x", 0).label("y", 2).value(0.00000095367431640625); + builder.cell().label("x", 0).label("y", 3).value(-255.00); + + builder.cell().label("x", 1).label("y", 0).value(0.0); + builder.cell().label("x", 1).label("y", 1).value(-0.0); + builder.cell().label("x", 1).label("y", 2).value(Float.MIN_NORMAL); + builder.cell().label("x", 1).label("y", 3).value(0x1.feP+127); + + builder.cell().label("x", 2).label("y", 0).value(Float.POSITIVE_INFINITY); + builder.cell().label("x", 2).label("y", 1).value(Float.NEGATIVE_INFINITY); + builder.cell().label("x", 2).label("y", 2).value(Float.NaN); + builder.cell().label("x", 2).label("y", 3).value(-Float.NaN); + Tensor expected = builder.build(); + + String denseJson = "{\"values\":\"422849803580c37f0000800000807f7f7f80ff807fc0ffc0\"}"; + Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, decoded); + } + + @Test + public void testFloatVectorInHexForm() { + var builder = Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[3],y[4])")); + builder.cell().label("x", 0).label("y", 0).value(42.0); + builder.cell().label("x", 0).label("y", 1).value(1048577.0); + builder.cell().label("x", 0).label("y", 2).value(0.00000095367431640625); + builder.cell().label("x", 0).label("y", 3).value(-255.00); + + builder.cell().label("x", 1).label("y", 0).value(0.0); + builder.cell().label("x", 1).label("y", 1).value(-0.0); + builder.cell().label("x", 1).label("y", 2).value(Float.MIN_VALUE); + builder.cell().label("x", 1).label("y", 3).value(Float.MAX_VALUE); + + builder.cell().label("x", 2).label("y", 0).value(Float.POSITIVE_INFINITY); + builder.cell().label("x", 2).label("y", 1).value(Float.NEGATIVE_INFINITY); + builder.cell().label("x", 2).label("y", 2).value(Float.NaN); + builder.cell().label("x", 2).label("y", 3).value(-Float.NaN); + Tensor expected = builder.build(); + + String denseJson = "{\"values\":\"" + +"42280000"+"49800008"+"35800000"+"c37f0000" + +"00000000"+"80000000"+"00000001"+"7f7fffff" + +"7f800000"+"ff800000"+"7fc00000"+"ffc00000" + +"\"}"; + Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, decoded); + } + + @Test + public void testDoubleVectorInHexForm() { + var builder = Tensor.Builder.of(TensorType.fromSpec("tensor<double>(x[3],y[4])")); + builder.cell().label("x", 0).label("y", 0).value(42.0); + builder.cell().label("x", 0).label("y", 1).value(1048577.0); + builder.cell().label("x", 0).label("y", 2).value(0.00000095367431640625); + builder.cell().label("x", 0).label("y", 3).value(-255.00); + + builder.cell().label("x", 1).label("y", 0).value(0.0); + builder.cell().label("x", 1).label("y", 1).value(-0.0); + builder.cell().label("x", 1).label("y", 2).value(Double.MIN_VALUE); + builder.cell().label("x", 1).label("y", 3).value(Double.MAX_VALUE); + + builder.cell().label("x", 2).label("y", 0).value(Double.POSITIVE_INFINITY); + builder.cell().label("x", 2).label("y", 1).value(Double.NEGATIVE_INFINITY); + builder.cell().label("x", 2).label("y", 2).value(Double.NaN); + builder.cell().label("x", 2).label("y", 3).value(-Double.NaN); + Tensor expected = builder.build(); + + String denseJson = "{\"values\":\"" + +"4045000000000000"+"4130000100000000"+"3eb0000000000000"+"c06fe00000000000" + +"0000000000000000"+"8000000000000000"+"0000000000000001"+"7fefffffffffffff" + +"7ff0000000000000"+"fff0000000000000"+"7ff8000000000000"+"fff8000000000000" + +"\"}"; + Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, decoded); + } + + @Test public void testMixedTensorInMixedForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])")); builder.cell().label("x", 0).label("y", 0).value(2.0); |