summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java54
1 files changed, 54 insertions, 0 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 7c7391ff895..8de85c7a0b7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -169,6 +169,60 @@ public class JsonFormatTestCase {
}
@Test
+ public void testDenseInt8Tensor() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[2])"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 1).value(7.0);
+ Tensor tensor = builder.build();
+
+ String shortJson = """
+ {
+ "type":"tensor<int8>(x[2],y[2])",
+ "values":[[2,3],[5,7]]
+ }
+ """;
+ byte[] shortEncoded = JsonFormat.encode(tensor, true, false);
+ assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded));
+
+ String longJson = """
+ {
+ "type":"tensor<int8>(x[2],y[2])",
+ "cells":[
+ {"address":{"x":"0","y":"0"},"value":2},
+ {"address":{"x":"0","y":"1"},"value":3},
+ {"address":{"x":"1","y":"0"},"value":5},
+ {"address":{"x":"1","y":"1"},"value":7}
+ ]
+ }
+ """;
+ byte[] longEncoded = JsonFormat.encode(tensor, false, false);
+ assertEqualJson(longJson, new String(longEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longEncoded));
+
+ String shortDirectJson = """
+ [[2, 3], [5, 7]]
+ """;
+ byte[] shortDirectEncoded = JsonFormat.encode(tensor, true, true);
+ assertEqualJson(shortDirectJson, new String(shortDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortDirectEncoded));
+
+ String longDirectJson = """
+ [
+ {"address":{"x":"0","y":"0"},"value":2},
+ {"address":{"x":"0","y":"1"},"value":3},
+ {"address":{"x":"1","y":"0"},"value":5},
+ {"address":{"x":"1","y":"1"},"value":7}
+ ]
+ """;
+ byte[] longDirectEncoded = JsonFormat.encode(tensor, false, true);
+ assertEqualJson(longDirectJson, new String(longDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longDirectEncoded));
+ }
+
+ @Test
public void testMixedTensor() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[2])"));
builder.cell().label("x", "a").label("y", 0).value(2.0);