diff options
4 files changed, 68 insertions, 0 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index a0744128a11..bbd9962be77 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -10,6 +10,7 @@ import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.slime.Cursor; import com.yahoo.slime.Slime; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.JsonFormat; @@ -87,6 +88,11 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { } } Tensor result = evaluator.evaluate(); + + Optional<String> format = property(request, "format"); + if (format.isPresent() && format.get().equalsIgnoreCase("short") && result instanceof IndexedTensor) { + return new Response(200, JsonFormat.encodeShortForm((IndexedTensor) result)); + } return new Response(200, JsonFormat.encode(result)); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index df89919a76e..8034be6bb22 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -183,6 +183,16 @@ public class ModelsEvaluationHandlerTest { } @Test + public void testMnistSoftmaxEvaluateSpecificFunctionWithShortOutput() { + Map<String, String> properties = new HashMap<>(); + properties.put("Placeholder", inputTensorShortForm()); + properties.put("format", "short"); + String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; + String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"value\":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]}"; + handler.assertResponse(url, properties, 200, expected); + } + + @Test public void testMnistSavedDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 461e73e3611..80b37e43c3d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -9,6 +9,7 @@ import com.yahoo.slime.JsonDecoder; import com.yahoo.slime.ObjectTraverser; import com.yahoo.slime.Slime; import com.yahoo.slime.Type; +import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; @@ -44,6 +45,16 @@ public class JsonFormat { return com.yahoo.slime.JsonFormat.toJsonBytes(slime); } + /** Serializes the given tensor type and value into a short-form JSON format */ + public static byte[] encodeShortForm(IndexedTensor tensor) { + Slime slime = new Slime(); + Cursor root = slime.setObject(); + root.setString("type", tensor.type().toString()); + Cursor value = root.setArray("value"); + encodeList(tensor, value, new long[tensor.dimensionSizes().dimensions()], 0); + return com.yahoo.slime.JsonFormat.toJsonBytes(slime); + } + private static void encodeCells(Tensor tensor, Cursor rootObject) { Cursor cellsArray = rootObject.setArray("cells"); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { @@ -59,6 +70,17 @@ public class JsonFormat { addressObject.setString(type.dimensions().get(i).name(), address.label(i)); } + private static void encodeList(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) { + DimensionSizes sizes = tensor.dimensionSizes(); + for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); ++indexes[dimension]) { + if (dimension < (sizes.dimensions() - 1)) { + encodeList(tensor, cursor.addArray(), indexes, dimension + 1); + } else { + cursor.addDouble(tensor.get(indexes)); + } + } + } + /** Deserializes the given tensor from JSON format */ // NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module public static Tensor decode(TensorType type, byte[] jsonTensorValue) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 011c4b1fe12..2f1e3be9299 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -82,6 +83,30 @@ public class JsonFormatTestCase { } @Test + public void testDenseTensorShortForm() { + assertEncodeShortForm("tensor(x[]):[1.0, 2.0]", + "{\"type\":\"tensor(x[])\",\"value\":[1.0,2.0]}"); + assertEncodeShortForm("tensor<float>(x[]):[1.0, 2.0]", + "{\"type\":\"tensor<float>(x[])\",\"value\":[1.0,2.0]}"); + assertEncodeShortForm("tensor(x[],y[]):[[1,2,3,4]]", + "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0,3.0,4.0]]}"); + assertEncodeShortForm("tensor(x[],y[]):[[1,2],[3,4]]", + "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0],[3.0,4.0]]}"); + assertEncodeShortForm("tensor(x[],y[]):[[1],[2],[3],[4]]", + "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0],[2.0],[3.0],[4.0]]}"); + assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2],[3,4]]]", + "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0],[3.0,4.0]]]}"); + assertEncodeShortForm("tensor(x[],y[],z[]):[[[1],[2],[3],[4]]]", + "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0],[2.0],[3.0],[4.0]]]}"); + assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2,3,4]]]", + "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0,3.0,4.0]]]}"); + assertEncodeShortForm("tensor(x[],y[],z[]):[[[1]],[[2]],[[3]],[[4]]]", + "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0]],[[2.0]],[[3.0]],[[4.0]]]}"); + assertEncodeShortForm("tensor(x[],y[],z[2]):[[[1, 2]],[[3, 4]]]", + "{\"type\":\"tensor(x[],y[],z[2])\",\"value\":[[[1.0,2.0]],[[3.0,4.0]]]}"); + } + + @Test public void testInt8VectorInHexForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])")); builder.cell().label("x", 0).label("y", 0).value(2.0); @@ -274,4 +299,9 @@ public class JsonFormatTestCase { assertEncodeDecode(Tensor.from("tensor<int8>(x[2],y[2]):[2,3,5,8]")); } + private void assertEncodeShortForm(String tensor, String expected) { + byte[] json = JsonFormat.encodeShortForm((IndexedTensor) Tensor.from(tensor)); + assertEquals(expected, new String(json, StandardCharsets.UTF_8)); + } + } |