aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-03-13 19:18:13 +0100
committerGitHub <noreply@github.com>2023-03-13 19:18:13 +0100
commitbccf4389dd29f34c26be1c6c1387096321d79323 (patch)
treec53ab8dada7b84c7c43e0e175b4fafd9ab88e76c
parent6c18071acebb4d783fd2dfb9b2ad49deb3b86c58 (diff)
parent23cfa731a3279f0d1fbf1e7f6a5f35cfa194ecc4 (diff)
Merge pull request #26422 from vespa-engine/bratseth/int8-renderingv8.137.25
Value type aware value rendering
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java23
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java54
2 files changed, 72 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index b7e6e67ce73..45e581d73e8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -25,7 +25,6 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
-import java.util.stream.Collectors;
/**
* Writes tensors on the JSON format used in Vespa tensor document fields:
@@ -113,14 +112,14 @@ public class JsonFormat {
Tensor.Cell cell = i.next();
Cursor cellObject = cellsArray.addObject();
encodeAddress(tensor.type(), cell.getKey(), cellObject.setObject("address"));
- cellObject.setDouble("value", cell.getValue());
+ setValue("value", cell.getValue(), tensor.type().valueType(), cellObject);
}
}
private static void encodeSingleDimensionCells(MappedTensor tensor, Cursor cells) {
if (tensor.type().dimensions().size() > 1)
throw new IllegalStateException("JSON encode of mapped tensor can only contain a single dimension");
- tensor.cells().forEach((k,v) -> cells.setDouble(k.label(0), v));
+ tensor.cells().forEach((k,v) -> setValue(k.label(0), v, tensor.type().valueType(), cells));
}
private static void encodeAddress(TensorType type, TensorAddress address, Cursor addressObject) {
@@ -131,13 +130,13 @@ public class JsonFormat {
private static void encodeValues(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) {
DimensionSizes sizes = tensor.dimensionSizes();
if (indexes.length == 0) {
- cursor.addDouble(tensor.get(0));
+ addValue(tensor.get(0), tensor.type().valueType(), cursor);
} else {
for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); ++indexes[dimension]) {
if (dimension < (sizes.dimensions() - 1)) {
encodeValues(tensor, cursor.addArray(), indexes, dimension + 1);
} else {
- cursor.addDouble(tensor.get(indexes));
+ addValue(tensor.get(indexes), tensor.type().valueType(), cursor);
}
}
}
@@ -174,6 +173,20 @@ public class JsonFormat {
}
}
+ private static void addValue(double value, TensorType.Value valueType, Cursor cursor) {
+ if (valueType == TensorType.Value.INT8)
+ cursor.addLong((long)value);
+ else
+ cursor.addDouble(value);
+ }
+
+ private static void setValue(String field, double value, TensorType.Value valueType, Cursor cursor) {
+ if (valueType == TensorType.Value.INT8)
+ cursor.setLong(field, (long)value);
+ else
+ cursor.setDouble(field, value);
+ }
+
private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) {
TensorAddress.Builder builder = new TensorAddress.Builder(subType);
for (TensorType.Dimension dim : subType.dimensions()) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 7c7391ff895..8de85c7a0b7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -169,6 +169,60 @@ public class JsonFormatTestCase {
}
@Test
+ public void testDenseInt8Tensor() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[2])"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 1).value(7.0);
+ Tensor tensor = builder.build();
+
+ String shortJson = """
+ {
+ "type":"tensor<int8>(x[2],y[2])",
+ "values":[[2,3],[5,7]]
+ }
+ """;
+ byte[] shortEncoded = JsonFormat.encode(tensor, true, false);
+ assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded));
+
+ String longJson = """
+ {
+ "type":"tensor<int8>(x[2],y[2])",
+ "cells":[
+ {"address":{"x":"0","y":"0"},"value":2},
+ {"address":{"x":"0","y":"1"},"value":3},
+ {"address":{"x":"1","y":"0"},"value":5},
+ {"address":{"x":"1","y":"1"},"value":7}
+ ]
+ }
+ """;
+ byte[] longEncoded = JsonFormat.encode(tensor, false, false);
+ assertEqualJson(longJson, new String(longEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longEncoded));
+
+ String shortDirectJson = """
+ [[2, 3], [5, 7]]
+ """;
+ byte[] shortDirectEncoded = JsonFormat.encode(tensor, true, true);
+ assertEqualJson(shortDirectJson, new String(shortDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortDirectEncoded));
+
+ String longDirectJson = """
+ [
+ {"address":{"x":"0","y":"0"},"value":2},
+ {"address":{"x":"0","y":"1"},"value":3},
+ {"address":{"x":"1","y":"0"},"value":5},
+ {"address":{"x":"1","y":"1"},"value":7}
+ ]
+ """;
+ byte[] longDirectEncoded = JsonFormat.encode(tensor, false, true);
+ assertEqualJson(longDirectJson, new String(longDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longDirectEncoded));
+ }
+
+ @Test
public void testMixedTensor() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[2])"));
builder.cell().label("x", "a").label("y", 0).value(2.0);