diff options
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 7c7391ff895..8de85c7a0b7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -169,6 +169,60 @@ public class JsonFormatTestCase { } @Test + public void testDenseInt8Tensor() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[2])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(7.0); + Tensor tensor = builder.build(); + + String shortJson = """ + { + "type":"tensor<int8>(x[2],y[2])", + "values":[[2,3],[5,7]] + } + """; + byte[] shortEncoded = JsonFormat.encode(tensor, true, false); + assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8)); + assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded)); + + String longJson = """ + { + "type":"tensor<int8>(x[2],y[2])", + "cells":[ + {"address":{"x":"0","y":"0"},"value":2}, + {"address":{"x":"0","y":"1"},"value":3}, + {"address":{"x":"1","y":"0"},"value":5}, + {"address":{"x":"1","y":"1"},"value":7} + ] + } + """; + byte[] longEncoded = JsonFormat.encode(tensor, false, false); + assertEqualJson(longJson, new String(longEncoded, StandardCharsets.UTF_8)); + assertEquals(tensor, JsonFormat.decode(tensor.type(), longEncoded)); + + String shortDirectJson = """ + [[2, 3], [5, 7]] + """; + byte[] shortDirectEncoded = JsonFormat.encode(tensor, true, true); + assertEqualJson(shortDirectJson, new String(shortDirectEncoded, StandardCharsets.UTF_8)); + assertEquals(tensor, JsonFormat.decode(tensor.type(), shortDirectEncoded)); + + String longDirectJson = """ + [ + {"address":{"x":"0","y":"0"},"value":2}, + {"address":{"x":"0","y":"1"},"value":3}, + {"address":{"x":"1","y":"0"},"value":5}, + {"address":{"x":"1","y":"1"},"value":7} + ] + """; + byte[] longDirectEncoded = JsonFormat.encode(tensor, false, true); + assertEqualJson(longDirectJson, new String(longDirectEncoded, StandardCharsets.UTF_8)); + assertEquals(tensor, JsonFormat.decode(tensor.type(), longDirectEncoded)); + } + + @Test public void testMixedTensor() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[2])")); builder.cell().label("x", "a").label("y", 0).value(2.0); |