summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-09-01 15:26:11 +0200
committerLester Solbakken <lesters@oath.com>2021-09-01 15:26:11 +0200
commit8286b1d9394a3f89c08b0d193d65d44e937be017 (patch)
treeae0f97745055aba39feae4ba59b9a3594a1a9b01 /vespajlib
parentb8ef0eedfd29827a98d561d65b4c657ecbadf243 (diff)
Add short form output option to model-evaluation REST API
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java30
2 files changed, 52 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 461e73e3611..80b37e43c3d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -9,6 +9,7 @@ import com.yahoo.slime.JsonDecoder;
import com.yahoo.slime.ObjectTraverser;
import com.yahoo.slime.Slime;
import com.yahoo.slime.Type;
+import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
@@ -44,6 +45,16 @@ public class JsonFormat {
return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
}
+ /** Serializes the given tensor type and value into a short-form JSON format */
+ public static byte[] encodeShortForm(IndexedTensor tensor) {
+ Slime slime = new Slime();
+ Cursor root = slime.setObject();
+ root.setString("type", tensor.type().toString());
+ Cursor value = root.setArray("value");
+ encodeList(tensor, value, new long[tensor.dimensionSizes().dimensions()], 0);
+ return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
+ }
+
private static void encodeCells(Tensor tensor, Cursor rootObject) {
Cursor cellsArray = rootObject.setArray("cells");
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
@@ -59,6 +70,17 @@ public class JsonFormat {
addressObject.setString(type.dimensions().get(i).name(), address.label(i));
}
+ private static void encodeList(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) {
+ DimensionSizes sizes = tensor.dimensionSizes();
+ for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); ++indexes[dimension]) {
+ if (dimension < (sizes.dimensions() - 1)) {
+ encodeList(tensor, cursor.addArray(), indexes, dimension + 1);
+ } else {
+ cursor.addDouble(tensor.get(indexes));
+ }
+ }
+ }
+
/** Deserializes the given tensor from JSON format */
// NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 011c4b1fe12..2f1e3be9299 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -82,6 +83,30 @@ public class JsonFormatTestCase {
}
@Test
+ public void testDenseTensorShortForm() {
+ assertEncodeShortForm("tensor(x[]):[1.0, 2.0]",
+ "{\"type\":\"tensor(x[])\",\"value\":[1.0,2.0]}");
+ assertEncodeShortForm("tensor<float>(x[]):[1.0, 2.0]",
+ "{\"type\":\"tensor<float>(x[])\",\"value\":[1.0,2.0]}");
+ assertEncodeShortForm("tensor(x[],y[]):[[1,2,3,4]]",
+ "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0,3.0,4.0]]}");
+ assertEncodeShortForm("tensor(x[],y[]):[[1,2],[3,4]]",
+ "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0],[3.0,4.0]]}");
+ assertEncodeShortForm("tensor(x[],y[]):[[1],[2],[3],[4]]",
+ "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0],[2.0],[3.0],[4.0]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2],[3,4]]]",
+ "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0],[3.0,4.0]]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[]):[[[1],[2],[3],[4]]]",
+ "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0],[2.0],[3.0],[4.0]]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2,3,4]]]",
+ "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0,3.0,4.0]]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[]):[[[1]],[[2]],[[3]],[[4]]]",
+ "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0]],[[2.0]],[[3.0]],[[4.0]]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[2]):[[[1, 2]],[[3, 4]]]",
+ "{\"type\":\"tensor(x[],y[],z[2])\",\"value\":[[[1.0,2.0]],[[3.0,4.0]]]}");
+ }
+
+ @Test
public void testInt8VectorInHexForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])"));
builder.cell().label("x", 0).label("y", 0).value(2.0);
@@ -274,4 +299,9 @@ public class JsonFormatTestCase {
assertEncodeDecode(Tensor.from("tensor<int8>(x[2],y[2]):[2,3,5,8]"));
}
+ private void assertEncodeShortForm(String tensor, String expected) {
+ byte[] json = JsonFormat.encodeShortForm((IndexedTensor) Tensor.from(tensor));
+ assertEquals(expected, new String(json, StandardCharsets.UTF_8));
+ }
+
}