diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-03-13 19:18:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-13 19:18:13 +0100 |
commit | bccf4389dd29f34c26be1c6c1387096321d79323 (patch) | |
tree | c53ab8dada7b84c7c43e0e175b4fafd9ab88e76c | |
parent | 6c18071acebb4d783fd2dfb9b2ad49deb3b86c58 (diff) | |
parent | 23cfa731a3279f0d1fbf1e7f6a5f35cfa194ecc4 (diff) |
Merge pull request #26422 from vespa-engine/bratseth/int8-renderingv8.137.25
Value type aware value rendering
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 23 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 54 |
2 files changed, 72 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index b7e6e67ce73..45e581d73e8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -25,7 +25,6 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; /** * Writes tensors on the JSON format used in Vespa tensor document fields: @@ -113,14 +112,14 @@ public class JsonFormat { Tensor.Cell cell = i.next(); Cursor cellObject = cellsArray.addObject(); encodeAddress(tensor.type(), cell.getKey(), cellObject.setObject("address")); - cellObject.setDouble("value", cell.getValue()); + setValue("value", cell.getValue(), tensor.type().valueType(), cellObject); } } private static void encodeSingleDimensionCells(MappedTensor tensor, Cursor cells) { if (tensor.type().dimensions().size() > 1) throw new IllegalStateException("JSON encode of mapped tensor can only contain a single dimension"); - tensor.cells().forEach((k,v) -> cells.setDouble(k.label(0), v)); + tensor.cells().forEach((k,v) -> setValue(k.label(0), v, tensor.type().valueType(), cells)); } private static void encodeAddress(TensorType type, TensorAddress address, Cursor addressObject) { @@ -131,13 +130,13 @@ public class JsonFormat { private static void encodeValues(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) { DimensionSizes sizes = tensor.dimensionSizes(); if (indexes.length == 0) { - cursor.addDouble(tensor.get(0)); + addValue(tensor.get(0), tensor.type().valueType(), cursor); } else { for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); ++indexes[dimension]) { if (dimension < (sizes.dimensions() - 1)) { encodeValues(tensor, cursor.addArray(), indexes, dimension + 1); } else { - cursor.addDouble(tensor.get(indexes)); + addValue(tensor.get(indexes), tensor.type().valueType(), cursor); } } } @@ -174,6 +173,20 @@ public class JsonFormat { } } + private static void addValue(double value, TensorType.Value valueType, Cursor cursor) { + if (valueType == TensorType.Value.INT8) + cursor.addLong((long)value); + else + cursor.addDouble(value); + } + + private static void setValue(String field, double value, TensorType.Value valueType, Cursor cursor) { + if (valueType == TensorType.Value.INT8) + cursor.setLong(field, (long)value); + else + cursor.setDouble(field, value); + } + private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) { TensorAddress.Builder builder = new TensorAddress.Builder(subType); for (TensorType.Dimension dim : subType.dimensions()) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 7c7391ff895..8de85c7a0b7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -169,6 +169,60 @@ public class JsonFormatTestCase { } @Test + public void testDenseInt8Tensor() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[2])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(7.0); + Tensor tensor = builder.build(); + + String shortJson = """ + { + "type":"tensor<int8>(x[2],y[2])", + "values":[[2,3],[5,7]] + } + """; + byte[] shortEncoded = JsonFormat.encode(tensor, true, false); + assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8)); + assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded)); + + String longJson = """ + { + "type":"tensor<int8>(x[2],y[2])", + "cells":[ + {"address":{"x":"0","y":"0"},"value":2}, + {"address":{"x":"0","y":"1"},"value":3}, + {"address":{"x":"1","y":"0"},"value":5}, + {"address":{"x":"1","y":"1"},"value":7} + ] + } + """; + byte[] longEncoded = JsonFormat.encode(tensor, false, false); + assertEqualJson(longJson, new String(longEncoded, StandardCharsets.UTF_8)); + assertEquals(tensor, JsonFormat.decode(tensor.type(), longEncoded)); + + String shortDirectJson = """ + [[2, 3], [5, 7]] + """; + byte[] shortDirectEncoded = JsonFormat.encode(tensor, true, true); + assertEqualJson(shortDirectJson, new String(shortDirectEncoded, StandardCharsets.UTF_8)); + assertEquals(tensor, JsonFormat.decode(tensor.type(), shortDirectEncoded)); + + String longDirectJson = """ + [ + {"address":{"x":"0","y":"0"},"value":2}, + {"address":{"x":"0","y":"1"},"value":3}, + {"address":{"x":"1","y":"0"},"value":5}, + {"address":{"x":"1","y":"1"},"value":7} + ] + """; + byte[] longDirectEncoded = JsonFormat.encode(tensor, false, true); + assertEqualJson(longDirectJson, new String(longDirectEncoded, StandardCharsets.UTF_8)); + assertEquals(tensor, JsonFormat.decode(tensor.type(), longDirectEncoded)); + } + + @Test public void testMixedTensor() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[2])")); builder.cell().label("x", "a").label("y", 0).value(2.0); |