aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-29 11:12:38 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-29 11:12:38 +0000
commit66e49b1201f9574bd42b208a56969db87716f2db (patch)
tree17f71449562b09096537346b52ce210e48bbe33e /vespajlib
parent126d4f78c4464c79f3365c433890119306693102 (diff)
allow a string (with a hex dump of binary representation) as cell values
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java103
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java125
3 files changed, 226 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 3378520dc91..7f9b2a67376 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -365,6 +365,8 @@ public interface Tensor {
}
static boolean approxEquals(double x, double y, double tolerance) {
+ if (x == y) return true;
+ if (Double.isNaN(x) && Double.isNaN(y)) return true;
return Math.abs(x-y) < tolerance;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index fa2094e9d2a..9eb9cb06666 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -103,10 +103,17 @@ public class JsonFormat {
if ( ! (builder instanceof IndexedTensor.BoundBuilder))
throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
"Use 'cells' or 'blocks' instead");
+ IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
+ if (values.type() == Type.STRING) {
+ double[] decoded = decodeHexString(values.asString(), builder.type().valueType());
+ for (int i = 0; i < decoded.length; i++) {
+ indexedBuilder.cellByDirectIndex(i, decoded[i]);
+ }
+ return;
+ }
if ( values.type() != Type.ARRAY)
throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type());
- IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
MutableInteger index = new MutableInteger(0);
values.traverse((ArrayTraverser) (__, value) -> {
if (value.type() != Type.LONG && value.type() != Type.DOUBLE)
@@ -143,11 +150,99 @@ public class JsonFormat {
decodeValues(value, mixedBuilder));
}
+ private static byte decodeHex(String input, int index) {
+ int d = Character.digit(input.charAt(index), 16);
+ if (d < 0) {
+ throw new IllegalArgumentException("Invalid digit '"+input.charAt(index)+"' at index "+index+" in input "+input);
+ }
+ return (byte)d;
+ }
+
+ private static double[] decodeHexStringAsBytes(String input) {
+ int l = input.length() / 2;
+ double[] result = new double[l];
+ int idx = 0;
+ for (int i = 0; i < l; i++) {
+ byte v = decodeHex(input, idx++);
+ v <<= 4;
+ v += decodeHex(input, idx++);
+ result[i] = v;
+ }
+ return result;
+ }
+
+ private static double[] decodeHexStringAsBFloat16s(String input) {
+ int l = input.length() / 4;
+ double[] result = new double[l];
+ int idx = 0;
+ for (int i = 0; i < l; i++) {
+ int v = decodeHex(input, idx++);
+ v <<= 4; v += decodeHex(input, idx++);
+ v <<= 4; v += decodeHex(input, idx++);
+ v <<= 4; v += decodeHex(input, idx++);
+ v <<= 16;
+ result[i] = Float.intBitsToFloat(v);
+ }
+ return result;
+ }
+
+ private static double[] decodeHexStringAsFloats(String input) {
+ int l = input.length() / 8;
+ double[] result = new double[l];
+ int idx = 0;
+ for (int i = 0; i < l; i++) {
+ int v = 0;
+ for (int j = 0; j < 8; j++) {
+ v <<= 4;
+ v += decodeHex(input, idx++);
+ }
+ result[i] = Float.intBitsToFloat(v);
+ }
+ return result;
+ }
+
+ private static double[] decodeHexStringAsDoubles(String input) {
+ int l = input.length() / 16;
+ double[] result = new double[l];
+ int idx = 0;
+ for (int i = 0; i < l; i++) {
+ long v = 0;
+ for (int j = 0; j < 16; j++) {
+ v <<= 4;
+ v += decodeHex(input, idx++);
+ }
+ result[i] = Double.longBitsToDouble(v);
+ }
+ return result;
+ }
+
+ private static double[] decodeHexString(String input, TensorType.Value valueType) {
+ switch(valueType) {
+ case INT8:
+ return decodeHexStringAsBytes(input);
+ case BFLOAT16:
+ return decodeHexStringAsBFloat16s(input);
+ case FLOAT:
+ return decodeHexStringAsFloats(input);
+ case DOUBLE:
+ return decodeHexStringAsDoubles(input);
+ default:
+ throw new IllegalArgumentException("Cannot handle value type: "+valueType);
+ }
+ }
+
private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
- if (valuesField.type() != Type.ARRAY)
- throw new IllegalArgumentException("Expected a block to contain a 'values' array");
double[] values = new double[(int)mixedBuilder.denseSubspaceSize()];
- valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value));
+ if (valuesField.type() == Type.ARRAY) {
+ valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value));
+ } else if (valuesField.type() == Type.STRING) {
+ double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType());
+ for (int i = 0; i < decoded.length; i++) {
+ values[i] = decoded[i];
+ }
+ } else {
+ throw new IllegalArgumentException("Expected a block to contain a 'values' array");
+ }
return values;
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 3ca20661587..011c4b1fe12 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -82,6 +82,131 @@ public class JsonFormatTestCase {
}
@Test
+ public void testInt8VectorInHexForm() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(127.0);
+ builder.cell().label("x", 0).label("y", 2).value(-1.0);
+ builder.cell().label("x", 1).label("y", 0).value(-128.0);
+ builder.cell().label("x", 1).label("y", 1).value(0.0);
+ builder.cell().label("x", 1).label("y", 2).value(42.0);
+ Tensor expected = builder.build();
+ String denseJson = "{\"values\":\"027FFF80002A\"}";
+ Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8));
+ assertEquals(expected, decoded);
+ }
+
+ @Test
+ public void testInt8VectorInvalidHex() {
+ var type = TensorType.fromSpec("tensor<int8>(x[2])");
+ String denseJson = "{\"values\":\"abXc\"}";
+ try {
+ Tensor decoded = JsonFormat.decode(type, denseJson.getBytes(StandardCharsets.UTF_8));
+ fail("did not get exception as expected, decoded as: "+decoded);
+ } catch (IllegalArgumentException e) {
+ assertEquals(e.getMessage(), "Invalid digit 'X' at index 2 in input abXc");
+ }
+ }
+
+ @Test
+ public void testMixedInt8TensorWithHexForm() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x{},y[3])"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("x", 0).label("y", 2).value(4.0);
+ builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 1).value(6.0);
+ builder.cell().label("x", 1).label("y", 2).value(7.0);
+ Tensor expected = builder.build();
+ String mixedJson = "{\"blocks\":[" +
+ "{\"address\":{\"x\":\"0\"},\"values\":\"020304\"}," +
+ "{\"address\":{\"x\":\"1\"},\"values\":\"050607\"}" +
+ "]}";
+ Tensor decoded = JsonFormat.decode(expected.type(), mixedJson.getBytes(StandardCharsets.UTF_8));
+ assertEquals(expected, decoded);
+ }
+
+ @Test
+ public void testBFloat16VectorInHexForm() {
+ var builder = Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[3],y[4])"));
+ builder.cell().label("x", 0).label("y", 0).value(42.0);
+ builder.cell().label("x", 0).label("y", 1).value(1048576.0);
+ builder.cell().label("x", 0).label("y", 2).value(0.00000095367431640625);
+ builder.cell().label("x", 0).label("y", 3).value(-255.00);
+
+ builder.cell().label("x", 1).label("y", 0).value(0.0);
+ builder.cell().label("x", 1).label("y", 1).value(-0.0);
+ builder.cell().label("x", 1).label("y", 2).value(Float.MIN_NORMAL);
+ builder.cell().label("x", 1).label("y", 3).value(0x1.feP+127);
+
+ builder.cell().label("x", 2).label("y", 0).value(Float.POSITIVE_INFINITY);
+ builder.cell().label("x", 2).label("y", 1).value(Float.NEGATIVE_INFINITY);
+ builder.cell().label("x", 2).label("y", 2).value(Float.NaN);
+ builder.cell().label("x", 2).label("y", 3).value(-Float.NaN);
+ Tensor expected = builder.build();
+
+ String denseJson = "{\"values\":\"422849803580c37f0000800000807f7f7f80ff807fc0ffc0\"}";
+ Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8));
+ assertEquals(expected, decoded);
+ }
+
+ @Test
+ public void testFloatVectorInHexForm() {
+ var builder = Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[3],y[4])"));
+ builder.cell().label("x", 0).label("y", 0).value(42.0);
+ builder.cell().label("x", 0).label("y", 1).value(1048577.0);
+ builder.cell().label("x", 0).label("y", 2).value(0.00000095367431640625);
+ builder.cell().label("x", 0).label("y", 3).value(-255.00);
+
+ builder.cell().label("x", 1).label("y", 0).value(0.0);
+ builder.cell().label("x", 1).label("y", 1).value(-0.0);
+ builder.cell().label("x", 1).label("y", 2).value(Float.MIN_VALUE);
+ builder.cell().label("x", 1).label("y", 3).value(Float.MAX_VALUE);
+
+ builder.cell().label("x", 2).label("y", 0).value(Float.POSITIVE_INFINITY);
+ builder.cell().label("x", 2).label("y", 1).value(Float.NEGATIVE_INFINITY);
+ builder.cell().label("x", 2).label("y", 2).value(Float.NaN);
+ builder.cell().label("x", 2).label("y", 3).value(-Float.NaN);
+ Tensor expected = builder.build();
+
+ String denseJson = "{\"values\":\""
+ +"42280000"+"49800008"+"35800000"+"c37f0000"
+ +"00000000"+"80000000"+"00000001"+"7f7fffff"
+ +"7f800000"+"ff800000"+"7fc00000"+"ffc00000"
+ +"\"}";
+ Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8));
+ assertEquals(expected, decoded);
+ }
+
+ @Test
+ public void testDoubleVectorInHexForm() {
+ var builder = Tensor.Builder.of(TensorType.fromSpec("tensor<double>(x[3],y[4])"));
+ builder.cell().label("x", 0).label("y", 0).value(42.0);
+ builder.cell().label("x", 0).label("y", 1).value(1048577.0);
+ builder.cell().label("x", 0).label("y", 2).value(0.00000095367431640625);
+ builder.cell().label("x", 0).label("y", 3).value(-255.00);
+
+ builder.cell().label("x", 1).label("y", 0).value(0.0);
+ builder.cell().label("x", 1).label("y", 1).value(-0.0);
+ builder.cell().label("x", 1).label("y", 2).value(Double.MIN_VALUE);
+ builder.cell().label("x", 1).label("y", 3).value(Double.MAX_VALUE);
+
+ builder.cell().label("x", 2).label("y", 0).value(Double.POSITIVE_INFINITY);
+ builder.cell().label("x", 2).label("y", 1).value(Double.NEGATIVE_INFINITY);
+ builder.cell().label("x", 2).label("y", 2).value(Double.NaN);
+ builder.cell().label("x", 2).label("y", 3).value(-Double.NaN);
+ Tensor expected = builder.build();
+
+ String denseJson = "{\"values\":\""
+ +"4045000000000000"+"4130000100000000"+"3eb0000000000000"+"c06fe00000000000"
+ +"0000000000000000"+"8000000000000000"+"0000000000000001"+"7fefffffffffffff"
+ +"7ff0000000000000"+"fff0000000000000"+"7ff8000000000000"+"fff8000000000000"
+ +"\"}";
+ Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8));
+ assertEquals(expected, decoded);
+ }
+
+ @Test
public void testMixedTensorInMixedForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])"));
builder.cell().label("x", 0).label("y", 0).value(2.0);