aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java6
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java30
4 files changed, 68 insertions, 0 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index a0744128a11..bbd9962be77 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -10,6 +10,7 @@ import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Slime;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.JsonFormat;
@@ -87,6 +88,11 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
}
}
Tensor result = evaluator.evaluate();
+
+ Optional<String> format = property(request, "format");
+ if (format.isPresent() && format.get().equalsIgnoreCase("short") && result instanceof IndexedTensor) {
+ return new Response(200, JsonFormat.encodeShortForm((IndexedTensor) result));
+ }
return new Response(200, JsonFormat.encode(result));
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index df89919a76e..8034be6bb22 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -183,6 +183,16 @@ public class ModelsEvaluationHandlerTest {
}
@Test
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithShortOutput() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("Placeholder", inputTensorShortForm());
+ properties.put("format", "short");
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
+ String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"value\":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]}";
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 461e73e3611..80b37e43c3d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -9,6 +9,7 @@ import com.yahoo.slime.JsonDecoder;
import com.yahoo.slime.ObjectTraverser;
import com.yahoo.slime.Slime;
import com.yahoo.slime.Type;
+import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
@@ -44,6 +45,16 @@ public class JsonFormat {
return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
}
+ /** Serializes the given tensor type and value into a short-form JSON format */
+ public static byte[] encodeShortForm(IndexedTensor tensor) {
+ Slime slime = new Slime();
+ Cursor root = slime.setObject();
+ root.setString("type", tensor.type().toString());
+ Cursor value = root.setArray("value");
+ encodeList(tensor, value, new long[tensor.dimensionSizes().dimensions()], 0);
+ return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
+ }
+
private static void encodeCells(Tensor tensor, Cursor rootObject) {
Cursor cellsArray = rootObject.setArray("cells");
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
@@ -59,6 +70,17 @@ public class JsonFormat {
addressObject.setString(type.dimensions().get(i).name(), address.label(i));
}
+ private static void encodeList(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) {
+ DimensionSizes sizes = tensor.dimensionSizes();
+ for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); ++indexes[dimension]) {
+ if (dimension < (sizes.dimensions() - 1)) {
+ encodeList(tensor, cursor.addArray(), indexes, dimension + 1);
+ } else {
+ cursor.addDouble(tensor.get(indexes));
+ }
+ }
+ }
+
/** Deserializes the given tensor from JSON format */
// NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 011c4b1fe12..2f1e3be9299 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -82,6 +83,30 @@ public class JsonFormatTestCase {
}
@Test
+ public void testDenseTensorShortForm() {
+ assertEncodeShortForm("tensor(x[]):[1.0, 2.0]",
+ "{\"type\":\"tensor(x[])\",\"value\":[1.0,2.0]}");
+ assertEncodeShortForm("tensor<float>(x[]):[1.0, 2.0]",
+ "{\"type\":\"tensor<float>(x[])\",\"value\":[1.0,2.0]}");
+ assertEncodeShortForm("tensor(x[],y[]):[[1,2,3,4]]",
+ "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0,3.0,4.0]]}");
+ assertEncodeShortForm("tensor(x[],y[]):[[1,2],[3,4]]",
+ "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0],[3.0,4.0]]}");
+ assertEncodeShortForm("tensor(x[],y[]):[[1],[2],[3],[4]]",
+ "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0],[2.0],[3.0],[4.0]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2],[3,4]]]",
+ "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0],[3.0,4.0]]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[]):[[[1],[2],[3],[4]]]",
+ "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0],[2.0],[3.0],[4.0]]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2,3,4]]]",
+ "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0,3.0,4.0]]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[]):[[[1]],[[2]],[[3]],[[4]]]",
+ "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0]],[[2.0]],[[3.0]],[[4.0]]]}");
+ assertEncodeShortForm("tensor(x[],y[],z[2]):[[[1, 2]],[[3, 4]]]",
+ "{\"type\":\"tensor(x[],y[],z[2])\",\"value\":[[[1.0,2.0]],[[3.0,4.0]]]}");
+ }
+
+ @Test
public void testInt8VectorInHexForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])"));
builder.cell().label("x", 0).label("y", 0).value(2.0);
@@ -274,4 +299,9 @@ public class JsonFormatTestCase {
assertEncodeDecode(Tensor.from("tensor<int8>(x[2],y[2]):[2,3,5,8]"));
}
+ private void assertEncodeShortForm(String tensor, String expected) {
+ byte[] json = JsonFormat.encodeShortForm((IndexedTensor) Tensor.from(tensor));
+ assertEquals(expected, new String(json, StandardCharsets.UTF_8));
+ }
+
}