diff options
author | Lester Solbakken <lesters@oath.com> | 2021-09-01 15:26:11 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-09-01 15:26:11 +0200 |
commit | 8286b1d9394a3f89c08b0d193d65d44e937be017 (patch) | |
tree | ae0f97745055aba39feae4ba59b9a3594a1a9b01 /vespajlib | |
parent | b8ef0eedfd29827a98d561d65b4c657ecbadf243 (diff) |
Add short form output option to model-evaluation REST API
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 22 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 30 |
2 files changed, 52 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 461e73e3611..80b37e43c3d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -9,6 +9,7 @@ import com.yahoo.slime.JsonDecoder; import com.yahoo.slime.ObjectTraverser; import com.yahoo.slime.Slime; import com.yahoo.slime.Type; +import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; @@ -44,6 +45,16 @@ public class JsonFormat { return com.yahoo.slime.JsonFormat.toJsonBytes(slime); } + /** Serializes the given tensor type and value into a short-form JSON format */ + public static byte[] encodeShortForm(IndexedTensor tensor) { + Slime slime = new Slime(); + Cursor root = slime.setObject(); + root.setString("type", tensor.type().toString()); + Cursor value = root.setArray("value"); + encodeList(tensor, value, new long[tensor.dimensionSizes().dimensions()], 0); + return com.yahoo.slime.JsonFormat.toJsonBytes(slime); + } + private static void encodeCells(Tensor tensor, Cursor rootObject) { Cursor cellsArray = rootObject.setArray("cells"); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { @@ -59,6 +70,17 @@ public class JsonFormat { addressObject.setString(type.dimensions().get(i).name(), address.label(i)); } + private static void encodeList(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) { + DimensionSizes sizes = tensor.dimensionSizes(); + for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); ++indexes[dimension]) { + if (dimension < (sizes.dimensions() - 1)) { + encodeList(tensor, cursor.addArray(), indexes, dimension + 1); + } else { + cursor.addDouble(tensor.get(indexes)); + } + } + } + /** Deserializes the given tensor from JSON format */ // NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module public static Tensor decode(TensorType type, byte[] jsonTensorValue) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 011c4b1fe12..2f1e3be9299 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -82,6 +83,30 @@ public class JsonFormatTestCase { } @Test + public void testDenseTensorShortForm() { + assertEncodeShortForm("tensor(x[]):[1.0, 2.0]", + "{\"type\":\"tensor(x[])\",\"value\":[1.0,2.0]}"); + assertEncodeShortForm("tensor<float>(x[]):[1.0, 2.0]", + "{\"type\":\"tensor<float>(x[])\",\"value\":[1.0,2.0]}"); + assertEncodeShortForm("tensor(x[],y[]):[[1,2,3,4]]", + "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0,3.0,4.0]]}"); + assertEncodeShortForm("tensor(x[],y[]):[[1,2],[3,4]]", + "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0],[3.0,4.0]]}"); + assertEncodeShortForm("tensor(x[],y[]):[[1],[2],[3],[4]]", + "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0],[2.0],[3.0],[4.0]]}"); + assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2],[3,4]]]", + "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0],[3.0,4.0]]]}"); + assertEncodeShortForm("tensor(x[],y[],z[]):[[[1],[2],[3],[4]]]", + "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0],[2.0],[3.0],[4.0]]]}"); + assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2,3,4]]]", + "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0,3.0,4.0]]]}"); + assertEncodeShortForm("tensor(x[],y[],z[]):[[[1]],[[2]],[[3]],[[4]]]", + "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0]],[[2.0]],[[3.0]],[[4.0]]]}"); + assertEncodeShortForm("tensor(x[],y[],z[2]):[[[1, 2]],[[3, 4]]]", + "{\"type\":\"tensor(x[],y[],z[2])\",\"value\":[[[1.0,2.0]],[[3.0,4.0]]]}"); + } + + @Test public void testInt8VectorInHexForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])")); builder.cell().label("x", 0).label("y", 0).value(2.0); @@ -274,4 +299,9 @@ public class JsonFormatTestCase { assertEncodeDecode(Tensor.from("tensor<int8>(x[2],y[2]):[2,3,5,8]")); } + private void assertEncodeShortForm(String tensor, String expected) { + byte[] json = JsonFormat.encodeShortForm((IndexedTensor) Tensor.from(tensor)); + assertEquals(expected, new String(json, StandardCharsets.UTF_8)); + } + } |