aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-01-16 11:43:45 +0100
committerJon Bratseth <bratseth@gmail.com>2023-01-16 11:43:45 +0100
commit3f07bf2d9e6eae85c50aa8734694273c983f959b (patch)
treef528075cb0e877423d9d2e26d4f6925f6ff9784c
parent416f596b150ec159717bfd2f9b2ef70e4d4cd3dd (diff)
Test direct rendering
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java11
-rw-r--r--container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java226
-rw-r--r--document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java3
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java14
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java50
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java22
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java86
-rw-r--r--vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java66
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java69
-rw-r--r--vespajlib/src/main/java/com/yahoo/text/JSON.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java243
11 files changed, 616 insertions, 178 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
index 9498f860f88..352a31553e7 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
@@ -45,6 +45,7 @@ import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
import com.yahoo.search.result.NanNumber;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.JsonFormat;
import java.io.IOException;
@@ -817,14 +818,8 @@ public class JsonRenderer extends AsynchronousSectionedRenderer<Result> {
}
private void renderTensor(Optional<Tensor> tensor) throws IOException {
- if (tensor.isEmpty()) {
- generator().writeStartObject();
- generator().writeArrayFieldStart("cells");
- generator().writeEndArray();
- generator().writeEndObject();
- return;
- }
- generator().writeRawValue(new String(JsonFormat.encode(tensor.get(), settings.tensorShortForm, settings.tensorDirectValues),
+ generator().writeRawValue(new String(JsonFormat.encode(tensor.orElse(Tensor.Builder.of(TensorType.empty).build()),
+ settings.tensorShortForm, settings.tensorDirectValues),
StandardCharsets.UTF_8));
}
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java
index c1ede03a371..b3ed85911b9 100644
--- a/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java
@@ -53,6 +53,7 @@ import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.text.JSON;
import com.yahoo.text.Utf8;
import com.yahoo.yolean.Exceptions;
import com.yahoo.yolean.trace.TraceNode;
@@ -156,37 +157,136 @@ public class JsonRendererTestCase {
r.hits().add(h);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@Timeout(300)
- void testTensorShortForm() throws ExecutionException, InterruptedException, IOException {
- String expected = "{" +
- "\"root\":{" +
- "\"id\":\"toplevel\"," +
- "\"relevance\":1.0," +
- "\"fields\":{" +
- "\"totalCount\":1" +
- "}," +
- "\"children\":[{" +
- "\"id\":\"tensors\"," +
- "\"relevance\":1.0," +
- "\"fields\":{" +
- "\"tensor_standard\":{\"type\":\"tensor(x{},y{})\",\"cells\":[{\"address\":{\"x\":\"a\",\"y\":\"0\"},\"value\":1.0},{\"address\":{\"x\":\"b\",\"y\":\"1\"},\"value\":2.0}]}," +
- "\"tensor_indexed\":{\"type\":\"tensor(x[2],y[3])\",\"values\":[[1.0,2.0,3.0],[4.0,5.0,6.0]]}," +
- "\"tensor_single_mapped\":{\"type\":\"tensor(x{})\",\"cells\":{\"a\":1.0,\"b\":2.0}}," +
- "\"tensor_mixed\":{\"type\":\"tensor(x{},y[2])\",\"blocks\":{\"a\":[1.0,2.0],\"b\":[3.0,4.0]}}," +
- "\"summaryfeatures\":{" +
- "\"tensor_standard\":{\"type\":\"tensor(x{},y{})\",\"cells\":[{\"address\":{\"x\":\"a\",\"y\":\"0\"},\"value\":1.0},{\"address\":{\"x\":\"b\",\"y\":\"1\"},\"value\":2.0}]}," +
- "\"tensor_indexed\":{\"type\":\"tensor(x[2],y[3])\",\"values\":[[1.0,2.0,3.0],[4.0,5.0,6.0]]}," +
- "\"tensor_single_mapped\":{\"type\":\"tensor(x{})\",\"cells\":{\"a\":1.0,\"b\":2.0}}," +
- "\"tensor_mixed\":{\"type\":\"tensor(x{},y[2])\",\"blocks\":{\"a\":[1.0,2.0],\"b\":[3.0,4.0]}}" +
- "}" +
- "}" +
- "}]" +
- "}}\n";
+ void testTensorRendering() throws ExecutionException, InterruptedException, IOException {
+ String shortJson = """
+ {
+ "root": {
+ "id":"toplevel",
+ "relevance":1.0,
+ "fields":{
+ "totalCount":1
+ },
+ "children":[{
+ "id":"tensors",
+ "relevance":1.0,
+ "fields":{
+ "tensor_standard":{"type":"tensor(x{},y{})","cells":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"b","y":"1"},"value":2.0}]},
+ "tensor_indexed":{"type":"tensor(x[2],y[3])","values":[[1.0,2.0,3.0],[4.0,5.0,6.0]]},
+ "tensor_single_mapped":{"type":"tensor(x{})","cells":{"a":1.0,"b":2.0}},
+ "tensor_mixed":{"type":"tensor(x{},y[2])","blocks":{"a":[1.0,2.0],"b":[3.0,4.0]}},
+ "summaryfeatures":{
+ "tensor_standard":{"type":"tensor(x{},y{})","cells":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"b","y":"1"},"value":2.0}]},
+ "tensor_indexed":{"type":"tensor(x[2],y[3])","values":[[1.0,2.0,3.0],[4.0,5.0,6.0]]},
+ "tensor_single_mapped":{"type":"tensor(x{})","cells":{"a":1.0,"b":2.0}},
+ "tensor_mixed":{"type":"tensor(x{},y[2])","blocks":{"a":[1.0,2.0],"b":[3.0,4.0]}}
+ }
+ }
+ }]
+ }
+ }""";
+
+ String longJson = """
+ {
+ "root": {
+ "id":"toplevel",
+ "relevance":1.0,
+ "fields":{
+ "totalCount":1
+ },
+ "children":[{
+ "id":"tensors",
+ "relevance":1.0,
+ "fields":{
+ "tensor_standard":{"type":"tensor(x{},y{})","cells":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"b","y":"1"},"value":2.0}]},
+ "tensor_indexed":{"type":"tensor(x[2],y[3])","cells":[{"address":{"x":"0","y":"0"},"value":1.0},{"address":{"x":"0","y":"1"},"value":2.0},{"address":{"x":"0","y":"2"},"value":3.0},{"address":{"x":"1","y":"0"},"value":4.0},{"address":{"x":"1","y":"1"},"value":5.0},{"address":{"x":"1","y":"2"},"value":6.0}]},
+ "tensor_single_mapped":{"type":"tensor(x{})","cells":[{"address":{"x":"a"},"value":1.0},{"address":{"x":"b"},"value":2.0}]},
+ "tensor_mixed":{"type":"tensor(x{},y[2])","cells":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"a","y":"1"},"value":2.0},{"address":{"x":"b","y":"0"},"value":3.0},{"address":{"x":"b","y":"1"},"value":4.0}]},
+ "summaryfeatures":{
+ "tensor_standard":{"type":"tensor(x{},y{})","cells":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"b","y":"1"},"value":2.0}]},
+ "tensor_indexed":{"type":"tensor(x[2],y[3])","cells":[{"address":{"x":"0","y":"0"},"value":1.0},{"address":{"x":"0","y":"1"},"value":2.0},{"address":{"x":"0","y":"2"},"value":3.0},{"address":{"x":"1","y":"0"},"value":4.0},{"address":{"x":"1","y":"1"},"value":5.0},{"address":{"x":"1","y":"2"},"value":6.0}]},
+ "tensor_single_mapped":{"type":"tensor(x{})","cells":[{"address":{"x":"a"},"value":1.0},{"address":{"x":"b"},"value":2.0}]},
+ "tensor_mixed":{"type":"tensor(x{},y[2])","cells":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"a","y":"1"},"value":2.0},{"address":{"x":"b","y":"0"},"value":3.0},{"address":{"x":"b","y":"1"},"value":4.0}]}
+ }
+ }
+ }]
+ }
+ }""";
+
+ String shortDirectJson = """
+ {
+ "root": {
+ "id":"toplevel",
+ "relevance":1.0,
+ "fields":{
+ "totalCount":1
+ },
+ "children":[{
+ "id":"tensors",
+ "relevance":1.0,
+ "fields":{
+ "tensor_standard":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"b","y":"1"},"value":2.0}],
+ "tensor_indexed":[[1.0,2.0,3.0],[4.0,5.0,6.0]],
+ "tensor_single_mapped":{"a":1.0,"b":2.0},
+ "tensor_mixed":{"a":[1.0,2.0],"b":[3.0,4.0]},
+ "summaryfeatures":{
+ "tensor_standard":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"b","y":"1"},"value":2.0}],
+ "tensor_indexed":[[1.0,2.0,3.0],[4.0,5.0,6.0]],
+ "tensor_single_mapped":{"a":1.0,"b":2.0},
+ "tensor_mixed":{"a":[1.0,2.0],"b":[3.0,4.0]}
+ }
+ }
+ }]
+ }
+ }""";
+
+ String longDirectJson = """
+ {
+ "root": {
+ "id":"toplevel",
+ "relevance":1.0,
+ "fields":{
+ "totalCount":1
+ },
+ "children":[{
+ "id":"tensors",
+ "relevance":1.0,
+ "fields":{
+ "tensor_standard":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"b","y":"1"},"value":2.0}],
+ "tensor_indexed":[{"address":{"x":"0","y":"0"},"value":1.0},{"address":{"x":"0","y":"1"},"value":2.0},{"address":{"x":"0","y":"2"},"value":3.0},{"address":{"x":"1","y":"0"},"value":4.0},{"address":{"x":"1","y":"1"},"value":5.0},{"address":{"x":"1","y":"2"},"value":6.0}],
+ "tensor_single_mapped":[{"address":{"x":"a"},"value":1.0},{"address":{"x":"b"},"value":2.0}],
+ "tensor_mixed":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"a","y":"1"},"value":2.0},{"address":{"x":"b","y":"0"},"value":3.0},{"address":{"x":"b","y":"1"},"value":4.0}],
+ "summaryfeatures":{
+ "tensor_standard":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"b","y":"1"},"value":2.0}],
+ "tensor_indexed":[{"address":{"x":"0","y":"0"},"value":1.0},{"address":{"x":"0","y":"1"},"value":2.0},{"address":{"x":"0","y":"2"},"value":3.0},{"address":{"x":"1","y":"0"},"value":4.0},{"address":{"x":"1","y":"1"},"value":5.0},{"address":{"x":"1","y":"2"},"value":6.0}],
+ "tensor_single_mapped":[{"address":{"x":"a"},"value":1.0},{"address":{"x":"b"},"value":2.0}],
+ "tensor_mixed":[{"address":{"x":"a","y":"0"},"value":1.0},{"address":{"x":"a","y":"1"},"value":2.0},{"address":{"x":"b","y":"0"},"value":3.0},{"address":{"x":"b","y":"1"},"value":4.0}]
+ }
+ }
+ }]
+ }
+ }""";
+
+ assertTensorRendering(shortJson, "short");
+ assertTensorRendering(longJson, "long");
+ assertTensorRendering(shortDirectJson, "short-value");
+ assertTensorRendering(longDirectJson, "long-value");
+ try {
+ render(new Result(new Query("/?presentation.format.tensors=unknown")));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("Could not set 'presentation.format.tensors' to 'unknown': Value must be 'long', 'short', 'long-value', or 'short-value', not 'unknown'",
+ Exceptions.toMessageString(e));
+ }
+ }
+
+ private void assertTensorRendering(String expected, String format) throws ExecutionException, InterruptedException, IOException {
Slime slime = new Slime();
Cursor features = slime.setObject();
features.setData("tensor_standard", TypedBinaryFormat.encode(Tensor.from("tensor(x{},y{}):{ {x:a,y:0}:1.0, {x:b,y:1}:2.0 }")));
@@ -202,26 +302,16 @@ public class JsonRendererTestCase {
h.setField("tensor_mixed", new TensorFieldValue(Tensor.from("tensor(x{},y[2]):{a:[1,2], b:[3,4]}")));
h.setField("summaryfeatures", summaryFeatures);
- Result result1 = new Result(new Query("/?presentation.format.tensors=short"));
+ Result result1 = new Result(new Query("/?presentation.format.tensors=" + format));
result1.hits().add(h);
result1.setTotalHitCount(1L);
- String summary1 = render(result1);
- assertEqualJson(expected, summary1);
+ assertEqualJson(expected, render(result1));
- Result result2 = new Result(new Query("/?format.tensors=short"));
+ // Alias
+ Result result2 = new Result(new Query("/?format.tensors=" + format));
result2.hits().add(h);
result2.setTotalHitCount(1L);
- String summary2 = render(result2);
- assertEqualJson(expected, summary2);
-
- try {
- render(new Result(new Query("/?presentation.format.tensors=unknown")));
- fail("Expected exception");
- }
- catch (IllegalArgumentException e) {
- assertEquals("Could not set 'presentation.format.tensors' to 'unknown': Value must be 'long', 'short', 'long-value', or 'short-value', not 'unknown'",
- Exceptions.toMessageString(e));
- }
+ assertEqualJson(expected, render(result2));
}
@Test
@@ -241,7 +331,7 @@ public class JsonRendererTestCase {
+ " \"string\": \"stuff\","
+ " \"predicate\": \"a in [b]\","
+ " \"tensor1\": { \"type\": \"tensor(x{})\", \"cells\": { \"a\":2.0 } },"
- + " \"tensor2\": { \"cells\": [] },"
+ + " \"tensor2\": { \"type\": \"tensor()\", \"values\":[0.0] },"
+ " \"tensor3\": { \"type\": \"tensor(x{},y{})\", \"cells\": [ { \"address\": {\"x\": \"a\", \"y\": \"0\"}, \"value\":2.0 }, { \"address\": {\"x\": \"a\", \"y\": \"1\"}, \"value\":-1.0 } ] },"
+ " \"summaryfeatures\": {"
+ " \"scalar1\":1.5,"
@@ -281,7 +371,7 @@ public class JsonRendererTestCase {
r.hits().add(h);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
private FeatureData createSummaryFeatures() {
@@ -349,7 +439,7 @@ public class JsonRendererTestCase {
subQuery.trace("yellow", 1);
q.trace("marker", 1);
String summary = render(execution, r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -415,7 +505,7 @@ public class JsonRendererTestCase {
subQuery.trace(access, 1);
q.trace("marker", 1);
String summary = render(execution, r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -441,7 +531,7 @@ public class JsonRendererTestCase {
subQuery.trace("yellow", 1);
q.trace("marker", 1);
String summary = render(execution, r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@SuppressWarnings({"unchecked"})
@@ -562,7 +652,7 @@ public class JsonRendererTestCase {
execution.trace().traceNode().add(child);
q.trace("something", 1);
String summary = render(execution, r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -604,7 +694,7 @@ public class JsonRendererTestCase {
execution.trace().traceNode().add(child);
q.trace("something", 1);
String summary = render(execution, r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -653,7 +743,7 @@ public class JsonRendererTestCase {
childOfChild.add(new TraceNode("in OO languages, nesting is for birds", 0L));
execution.trace().traceNode().add(child);
String summary = render(execution, r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -741,7 +831,7 @@ public class JsonRendererTestCase {
r.hits().add(gg);
r.hits().addError(ErrorMessage.createInternalServerError("boom"));
String summary = render(execution, r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -776,7 +866,7 @@ public class JsonRendererTestCase {
r.setCoverage(new Coverage(500, 600).setDegradedReason(5));
String summary = render(execution, r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -813,7 +903,7 @@ public class JsonRendererTestCase {
r.hits().add(h);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -843,7 +933,7 @@ public class JsonRendererTestCase {
r.hits().add(h);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -871,7 +961,7 @@ public class JsonRendererTestCase {
r.hits().add(h);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -905,7 +995,7 @@ public class JsonRendererTestCase {
ErrorMessage e = new ErrorMessage(1234, "hello", "top of the day", t);
r.hits().addError(e);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -999,7 +1089,7 @@ public class JsonRendererTestCase {
r.hits().add(rg);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -1063,7 +1153,7 @@ public class JsonRendererTestCase {
r.hits().add(rg);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -1110,7 +1200,7 @@ public class JsonRendererTestCase {
h.setField("json producer", struct);
r.hits().add(h);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -1146,7 +1236,7 @@ public class JsonRendererTestCase {
r.hits().add(h);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -1172,7 +1262,7 @@ public class JsonRendererTestCase {
r.hits().add(h);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -1204,7 +1294,7 @@ public class JsonRendererTestCase {
r.hits().add(h);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -1240,7 +1330,7 @@ public class JsonRendererTestCase {
r.getElapsedTime().add(t);
renderer.setTimeSource(() -> 8L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
@Test
@@ -1278,7 +1368,7 @@ public class JsonRendererTestCase {
String json = summary.substring(jsonCallback.length() + 1, summary.length() - 2);
assertEquals(jsonCallback + "(", jsonCallbackBegin);
- assertEqualJson(expected, json);
+ assertEqualJsonContent(expected, json);
assertEquals(");", jsonCallbackEnd);
}
@@ -1327,7 +1417,7 @@ public class JsonRendererTestCase {
r.hits().add(h);
r.setTotalHitCount(1L);
String summary = render(r);
- assertEqualJson(expected, summary);
+ assertEqualJsonContent(expected, summary);
}
private static SlimeAdapter dataFromSimplified(String simplified) {
@@ -1512,8 +1602,14 @@ public class JsonRendererTestCase {
}
}
+ private void assertEqualJson(String expected, String generated) {
+ assertEquals("", validateJSON(expected));
+ assertEquals("", validateJSON(generated));
+ assertEquals(JSON.canonical(expected), JSON.canonical(generated));
+ }
+
@SuppressWarnings("unchecked")
- private void assertEqualJson(String expected, String generated) throws IOException {
+ private void assertEqualJsonContent(String expected, String generated) throws IOException {
assertEquals("", validateJSON(expected));
assertEquals("", validateJSON(generated));
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
index 7003d19e7d1..105739da508 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
@@ -166,9 +166,8 @@ public class TensorFieldValue extends FieldValue {
@Override
public boolean equals(Object o) {
if (this == o) return true;
- if ( ! (o instanceof TensorFieldValue)) return false;
+ if ( ! (o instanceof TensorFieldValue other)) return false;
- TensorFieldValue other = (TensorFieldValue)o;
if ( ! getTensorType().equals(other.getTensorType())) return false;
if ( ! getTensor().equals(other.getTensor())) return false;
return true;
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index 914ab670142..9d66bbd25d9 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -4,6 +4,8 @@ package com.yahoo.document.json.readers;
import com.fasterxml.jackson.core.JsonToken;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
+import com.yahoo.slime.Inspector;
+import com.yahoo.slime.Type;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
@@ -179,10 +181,12 @@ public class TensorReader {
/** Reads a tensor value directly at the root, where the format is decided by the tensor type. */
private static void readDirectTensorValue(TokenBuffer buffer, Tensor.Builder builder) {
- boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
+ boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed) && 1==2;
boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
- if ( ! hasMapped)
+ if (isArrayOfObjects(buffer, 0))
+ readTensorCells(buffer, builder);
+ else if ( ! hasMapped)
readTensorValues(buffer, builder);
else if (hasMapped && hasIndexed)
readTensorBlocks(buffer, builder);
@@ -190,6 +194,12 @@ public class TensorReader {
readTensorCells(buffer, builder);
}
+ private static boolean isArrayOfObjects(TokenBuffer buffer, int ahead) {
+ if (buffer.peek(ahead++) != JsonToken.START_ARRAY) return false;
+ if (buffer.peek(ahead) == JsonToken.START_ARRAY) return isArrayOfObjects(buffer, ahead); // nested array
+ return buffer.peek(ahead) == JsonToken.START_OBJECT;
+ }
+
private static TensorAddress readAddress(TokenBuffer buffer, TensorType type) {
expectObjectStart(buffer.current());
TensorAddress.Builder builder = new TensorAddress.Builder(type);
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index c19094ff231..f48c2330f82 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1427,14 +1427,16 @@ public class JsonReaderTestCase {
@Test
public void testParsingOfSparseTensorWithCells() {
Tensor tensor = assertSparseTensorField("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}",
- createPutWithSparseTensor(inputJson("{",
- " 'cells': [",
- " { 'address': { 'x': 'a', 'y': 'b' },",
- " 'value': 2.0 },",
- " { 'address': { 'x': 'c', 'y': 'b' },",
- " 'value': 3.0 }",
- " ]",
- "}")));
+ createPutWithSparseTensor(
+ """
+ {
+ "type": "tensor(x{},y{})",
+ "cells": [
+ { "address": { "x": "a", "y": "b" }, "value": 2.0 },
+ { "address": { "x": "c", "y": "b" }, "value": 3.0 }
+ ]
+ }
+ """));
assertTrue(tensor instanceof MappedTensor); // any functional instance is fine
}
@@ -1542,13 +1544,33 @@ public class JsonReaderTestCase {
builder.cell().label("x", 1).label("y", 2).value(7.0);
Tensor expected = builder.build();
- String mixedJson = "{\"blocks\":[" +
- "{\"address\":{\"x\":\"0\"},\"values\":[2.0,3.0,4.0]}," +
- "{\"address\":{\"x\":\"1\"},\"values\":[5.0,6.0,7.0]}" +
- "]}";
+ String mixedJson =
+ """
+ {
+ "blocks":[
+ {"address":{"x":"0"},"values":[2.0,3.0,4.0]},
+ {"address":{"x":"1"},"values":[5.0,6.0,7.0]}
+ ]
+ }
+ """;
Tensor tensor = assertTensorField(expected,
createPutWithTensor(inputJson(mixedJson), "mixed_tensor"), "mixed_tensor");
assertTrue(tensor instanceof MixedTensor); // this matters for performance
+
+ String mixedJsonDirect =
+ """
+ [
+ {"address":{"x":"0","y":"0"},"value":2.0},
+ {"address":{"x":"0","y":"1"},"value":3.0},
+ {"address":{"x":"0","y":"2"},"value":4.0},
+ {"address":{"x":"1","y":"0"},"value":5.0},
+ {"address":{"x":"1","y":"1"},"value":6.0},
+ {"address":{"x":"1","y":"2"},"value":7.0}
+ ]
+ """;
+ Tensor tensorDirect = assertTensorField(expected,
+ createPutWithTensor(inputJson(mixedJsonDirect), "mixed_tensor"), "mixed_tensor");
+ assertTrue(tensorDirect instanceof MixedTensor); // this matters for performance
}
@Test
@@ -1602,8 +1624,8 @@ public class JsonReaderTestCase {
@Test
public void testAssignUpdateOfNullTensor() {
ClearValueUpdate clearUpdate = (ClearValueUpdate) getTensorField(createAssignUpdateWithSparseTensor(null)).getValueUpdate(0);
- assertTrue(clearUpdate != null);
- assertTrue(clearUpdate.getValue() == null);
+ assertNotNull(clearUpdate);
+ assertNull(clearUpdate.getValue());
}
@Test
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
index 5fabfca8737..6c4dd886f4b 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -6,6 +6,7 @@ import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.JsonFormat;
+import com.yahoo.text.JSON;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -26,11 +27,18 @@ class HandlerTester {
}
private static Predicate<String> matchString(String expected) {
return s -> {
- // System.out.println("Expected: " + expected);
- // System.out.println("Actual: " + s);
+ //System.out.println("Expected: " + expected);
+ //System.out.println("Actual: " + s);
return expected.equals(s);
};
}
+ private static Predicate<String> matchJsonString(String expected) {
+ return s -> {
+ //System.out.println("Expected: " + expected);
+ //System.out.println("Actual: " + s);
+ return JSON.canonical(expected).equals(JSON.canonical(s));
+ };
+ }
public static Predicate<String> matchJson(String... expectedJson) {
var jExp = String.join("\n", expectedJson).replaceAll("'", "\"");
var expected = jsonToSlime(jExp);
@@ -72,6 +80,10 @@ class HandlerTester {
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) {
+ checkResponse(url, properties, expectedCode, matchJsonString(expectedResult), headers);
+ }
+
+ void assertStringResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) {
checkResponse(url, properties, expectedCode, matchString(expectedResult), headers);
}
@@ -91,15 +103,11 @@ class HandlerTester {
assertResponse(getRequest, expectedCode, expectedResult);
}
- void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
- checkResponse(request, expectedCode, matchString(expectedResult));
- }
-
void checkResponse(HttpRequest request, int expectedCode, Predicate<String> check) {
HttpResponse response = handler.handle(request);
assertEquals("application/json", response.getContentType());
- assertEquals(expectedCode, response.getStatus());
assertEquals(true, check.test(getContents(response)));
+ assertEquals(expectedCode, response.getStatus());
}
void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 50dbecaffce..9b2b793212b 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -191,22 +191,82 @@ public class ModelsEvaluationHandlerTest {
}
@Test
- public void testMnistSoftmaxEvaluateSpecificFunctionWithBindingsShortForm() {
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithShortOutput() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("Placeholder", inputTensorShortForm());
+ properties.put("format.tensors", "short");
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
+ String expected =
+ """
+ {
+ "type":"tensor(d0[],d1[10])",
+ "values":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]
+ }
+ """;
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithLongOutput() {
Map<String, String> properties = new HashMap<>();
properties.put("Placeholder", inputTensorShortForm());
properties.put("format.tensors", "long");
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
- String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
+ String expected =
+ """
+ {
+ "type":"tensor(d0[],d1[10])",
+ "cells":[
+ {"address":{"d0":"0","d1":"0"},"value":-0.3546536862850189},
+ {"address":{"d0":"0","d1":"1"},"value":0.3759574592113495},
+ {"address":{"d0":"0","d1":"2"},"value":0.06054411828517914},
+ {"address":{"d0":"0","d1":"3"},"value":-0.251544713973999},
+ {"address":{"d0":"0","d1":"4"},"value":0.017951013520359993},
+ {"address":{"d0":"0","d1":"5"},"value":1.2899067401885986},
+ {"address":{"d0":"0","d1":"6"},"value":-0.10389615595340729},
+ {"address":{"d0":"0","d1":"7"},"value":0.6367976665496826},
+ {"address":{"d0":"0","d1":"8"},"value":-1.4136744737625122},
+ {"address":{"d0":"0","d1":"9"},"value":-0.2573896050453186}
+ ]
+ }
+ """;
handler.assertResponse(url, properties, 200, expected);
}
@Test
- public void testMnistSoftmaxEvaluateSpecificFunctionWithShortOutput() {
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithShortDirectOutput() {
Map<String, String> properties = new HashMap<>();
properties.put("Placeholder", inputTensorShortForm());
- properties.put("format.tensors", "short");
+ properties.put("format.tensors", "short-value");
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
- String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"values\":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]}";
+ String expected =
+ """
+ [[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]
+ """;
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithLongDirectOutput() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("Placeholder", inputTensorShortForm());
+ properties.put("format.tensors", "long-value");
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
+ String expected =
+ """
+ [
+ {"address":{"d0":"0","d1":"0"},"value":-0.3546536862850189},
+ {"address":{"d0":"0","d1":"1"},"value":0.3759574592113495},
+ {"address":{"d0":"0","d1":"2"},"value":0.06054411828517914},
+ {"address":{"d0":"0","d1":"3"},"value":-0.251544713973999},
+ {"address":{"d0":"0","d1":"4"},"value":0.017951013520359993},
+ {"address":{"d0":"0","d1":"5"},"value":1.2899067401885986},
+ {"address":{"d0":"0","d1":"6"},"value":-0.10389615595340729},
+ {"address":{"d0":"0","d1":"7"},"value":0.6367976665496826},
+ {"address":{"d0":"0","d1":"8"},"value":-1.4136744737625122},
+ {"address":{"d0":"0","d1":"9"},"value":-0.2573896050453186}
+ ]
+ """;
handler.assertResponse(url, properties, 200, expected);
}
@@ -251,14 +311,14 @@ public class ModelsEvaluationHandlerTest {
Map<String, String> properties = new HashMap<>();
properties.put("format.tensors", "string");
String url = "http://localhost/model-evaluation/v1/vespa_model/";
- handler.assertResponse(url + "test_mapped/eval", properties, 200,
- "tensor(d0{}):{a:1.0, b:2.0}");
- handler.assertResponse(url + "test_indexed/eval", properties, 200,
- "tensor(d0[2],d1[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]");
- handler.assertResponse(url + "test_mixed/eval", properties, 200,
- "tensor(x{},y[3]):{a:[1.0, 2.0, 3.0], b:[4.0, 5.0, 6.0]}");
- handler.assertResponse(url + "test_mixed_2/eval", properties, 200,
- "tensor(a[2],b[2],c{},d[2]):{a:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], b:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]}");
+ handler.assertStringResponse(url + "test_mapped/eval", properties, 200,
+ "tensor(d0{}):{a:1.0, b:2.0}", Map.of());
+ handler.assertStringResponse(url + "test_indexed/eval", properties, 200,
+ "tensor(d0[2],d1[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]", Map.of());
+ handler.assertStringResponse(url + "test_mixed/eval", properties, 200,
+ "tensor(x{},y[3]):{a:[1.0, 2.0, 3.0], b:[4.0, 5.0, 6.0]}", Map.of());
+ handler.assertStringResponse(url + "test_mixed_2/eval", properties, 200,
+ "tensor(a[2],b[2],c{},d[2]):{a:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], b:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]}", Map.of());
}
@Test
diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java b/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java
index cc6b8567b03..b6ad7ba5570 100644
--- a/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java
+++ b/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java
@@ -521,16 +521,66 @@ public class DocumentV1ApiTest {
parameters.responseHandler().get().handleResponse(new DocumentResponse(0, doc1));
return new Result();
});
+ // -- short tensors
+ response = driver.sendRequest("http://localhost/document/v1/space/music/docid/one?format.tensors=short");
+ String shortJson =
+ """
+ {
+ "pathId": "/document/v1/space/music/docid/one",
+ "id": "id:space:music::one",
+ "fields": {
+ "artist": "Tom Waits",
+ "embedding": { "type": "tensor(x[3])","values": [1.0, 2.0, 3.0]}
+ }
+ }
+ """;
+ assertEquals(200, response.getStatus());
+ assertSameJson(shortJson, response.readAll());
+ // -- long tensors
response = driver.sendRequest("http://localhost/document/v1/space/music/docid/one?format.tensors=long");
- assertSameJson("{" +
- " \"pathId\": \"/document/v1/space/music/docid/one\"," +
- " \"id\": \"id:space:music::one\"," +
- " \"fields\": {" +
- " \"artist\": \"Tom Waits\"," +
- " \"embedding\": { \"type\": \"tensor(x[3])\",\"cells\": [{\"address\":{\"x\":\"0\"},\"value\":1.0},{\"address\":{\"x\":\"1\"},\"value\": 2.0},{\"address\":{\"x\":\"2\"},\"value\": 3.0}]}" +
- " }" +
- "}", response.readAll());
+ String longJson =
+ """
+ {
+ "pathId": "/document/v1/space/music/docid/one",
+ "id": "id:space:music::one",
+ "fields": {
+ "artist": "Tom Waits",
+ "embedding": { "type": "tensor(x[3])","cells": [{"address":{"x":"0"},"value":1.0},{"address":{"x":"1"},"value": 2.0},{"address":{"x":"2"},"value": 3.0}]}
+ }
+ }
+ """;
+ assertEquals(200, response.getStatus());
+ assertSameJson(longJson, response.readAll());
+ // -- short direct tensors
+ response = driver.sendRequest("http://localhost/document/v1/space/music/docid/one?format.tensors=short-value");
+ String shortDirectJson =
+ """
+ {
+ "pathId": "/document/v1/space/music/docid/one",
+ "id": "id:space:music::one",
+ "fields": {
+ "artist": "Tom Waits",
+ "embedding": [1.0, 2.0, 3.0]}
+ }
+ }
+ """;
+ assertEquals(200, response.getStatus());
+ assertSameJson(shortDirectJson, response.readAll());
+ // -- long direct tensors
+ response = driver.sendRequest("http://localhost/document/v1/space/music/docid/one?format.tensors=long-value");
+ String longDirectJson =
+ """
+ {
+ "pathId": "/document/v1/space/music/docid/one",
+ "id": "id:space:music::one",
+ "fields": {
+ "artist": "Tom Waits",
+ "embedding": [{"address":{"x":"0"},"value":1.0},{"address":{"x":"1"},"value": 2.0},{"address":{"x":"2"},"value": 3.0}]
+ }
+ }
+ """;
assertEquals(200, response.getStatus());
+ assertSameJson(longDirectJson, response.readAll());
// GET with not encoded / in user specified part of document id is perfectly OK ... щ(ಥДಥщ)
access.session.expect((id, parameters) -> {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 68997c82d3e..b7e6e67ce73 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -46,13 +46,13 @@ public class JsonFormat {
*/
public static byte[] encode(Tensor tensor, boolean shortForm, boolean directValues) {
Slime slime = new Slime();
- if (shortForm) {
- Cursor root = null;
- if ( ! directValues) {
- root = slime.setObject();
- root.setString("type", tensor.type().toString());
- }
+ Cursor root = null;
+ if ( ! directValues) {
+ root = slime.setObject();
+ root.setString("type", tensor.type().toString());
+ }
+ if (shortForm) {
if (tensor instanceof IndexedTensor denseTensor) {
// Encode as nested lists if indexed tensor
Cursor parent = root == null ? slime.setArray() : root.setArray("values");
@@ -77,9 +77,8 @@ public class JsonFormat {
return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
}
else {
- Cursor root = slime.setObject();
- root.setString("type", tensor.type().toString());
- encodeCells(tensor, root.setArray("cells"));
+ Cursor parent = root == null ? slime.setArray() : root.setArray("cells");
+ encodeCells(tensor, parent);
}
return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
}
@@ -241,48 +240,52 @@ public class JsonFormat {
}
private static void decodeValues(Inspector values, Tensor.Builder builder) {
+ decodeValues(values, builder, new MutableInteger(0));
+ }
+
+ private static void decodeValues(Inspector values, Tensor.Builder builder, MutableInteger index) {
if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder))
- throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
- "Use 'cells' or 'blocks' instead");
+ throw new IllegalArgumentException("An array of values can only be used with a dense tensor. Use a map instead");
if (values.type() == Type.STRING) {
double[] decoded = decodeHexString(values.asString(), builder.type().valueType());
if (decoded.length == 0)
- throw new IllegalArgumentException("The 'values' string does not contain any values");
+ throw new IllegalArgumentException("The values string does not contain any values");
for (int i = 0; i < decoded.length; i++) {
indexedBuilder.cellByDirectIndex(i, decoded[i]);
}
return;
}
if (values.type() != Type.ARRAY)
- throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type());
+ throw new IllegalArgumentException("Excepted values to be an array, not " + values.type());
if (values.entries() == 0)
- throw new IllegalArgumentException("The 'values' array does not contain any values");
+ throw new IllegalArgumentException("The values array does not contain any values");
- MutableInteger index = new MutableInteger(0);
values.traverse((ArrayTraverser) (__, value) -> {
- if (value.type() != Type.LONG && value.type() != Type.DOUBLE) {
- throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type());
- }
- indexedBuilder.cellByDirectIndex(index.next(), value.asDouble());
+ if (value.type() == Type.ARRAY)
+ decodeValues(value, builder, index);
+ else if (value.type() == Type.LONG || value.type() == Type.DOUBLE)
+ indexedBuilder.cellByDirectIndex(index.next(), value.asDouble());
+ else
+ throw new IllegalArgumentException("Excepted the values array to contain numbers or nested arrays, not " + value.type());
});
}
private static void decodeBlocks(Inspector values, Tensor.Builder builder) {
if ( ! (builder instanceof MixedTensor.BoundBuilder mixedBuilder))
- throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " +
- "Use 'cells' or 'values' instead");
+ throw new IllegalArgumentException("Blocks of values can only be used with mixed (sparse and dense) tensors." +
+ "Use an array of cell values instead.");
if (values.type() == Type.ARRAY)
values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder));
else if (values.type() == Type.OBJECT)
values.traverse((ObjectTraverser) (key, value) -> decodeSingleDimensionBlock(key, value, mixedBuilder));
else
- throw new IllegalArgumentException("Excepted 'blocks' to contain an array or object, not " + values.type());
+ throw new IllegalArgumentException("Excepted the block to contain an array or object, not " + values.type());
}
private static void decodeBlock(Inspector block, MixedTensor.BoundBuilder mixedBuilder) {
if (block.type() != Type.OBJECT)
- throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an object, not " + block.type());
+ throw new IllegalArgumentException("Expected an item in a blocks array to be an object, not " + block.type());
mixedBuilder.block(decodeAddress(block.field("address"), mixedBuilder.type().mappedSubtype()),
decodeValues(block.field("values"), mixedBuilder));
}
@@ -292,7 +295,9 @@ public class JsonFormat {
boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
- if ( ! hasMapped)
+ if (isArrayOfObjects(root))
+ decodeCells(root, builder);
+ else if ( ! hasMapped)
decodeValues(root, builder);
else if (hasMapped && hasIndexed)
decodeBlocks(root, builder);
@@ -300,9 +305,17 @@ public class JsonFormat {
decodeCells(root, builder);
}
+ private static boolean isArrayOfObjects(Inspector inspector) {
+ if (inspector.type() != Type.ARRAY) return false;
+ if (inspector.entries() == 0) return false;
+ Inspector firstItem = inspector.entry(0);
+ if (firstItem.type() == Type.ARRAY) return isArrayOfObjects(firstItem);
+ return firstItem.type() == Type.OBJECT;
+ }
+
private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) {
if (value.type() != Type.ARRAY)
- throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type());
+ throw new IllegalArgumentException("Expected an item in a blocks array to be an array, not " + value.type());
mixedBuilder.block(asAddress(key, mixedBuilder.type().mappedSubtype()),
decodeValues(value, mixedBuilder));
}
@@ -386,19 +399,19 @@ public class JsonFormat {
double[] values = new double[(int)mixedBuilder.denseSubspaceSize()];
if (valuesField.type() == Type.ARRAY) {
if (valuesField.entries() == 0) {
- throw new IllegalArgumentException("The 'block' value array does not contain any values");
+ throw new IllegalArgumentException("The block value array does not contain any values");
}
valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value));
} else if (valuesField.type() == Type.STRING) {
double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType());
if (decoded.length == 0) {
- throw new IllegalArgumentException("The 'block' value string does not contain any values");
+ throw new IllegalArgumentException("The block value string does not contain any values");
}
for (int i = 0; i < decoded.length; i++) {
values[i] = decoded[i];
}
} else {
- throw new IllegalArgumentException("Expected a block to contain a 'values' array");
+ throw new IllegalArgumentException("Expected a block to contain an array of values");
}
return values;
}
diff --git a/vespajlib/src/main/java/com/yahoo/text/JSON.java b/vespajlib/src/main/java/com/yahoo/text/JSON.java
index 6f8ef9a289f..8ef66b745cc 100644
--- a/vespajlib/src/main/java/com/yahoo/text/JSON.java
+++ b/vespajlib/src/main/java/com/yahoo/text/JSON.java
@@ -75,4 +75,8 @@ public final class JSON {
return leftSlime.equalTo(rightSlime);
}
+ public static String canonical(String jsonString) {
+ return SlimeUtils.jsonToSlimeOrThrow(jsonString).toString();
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 4692cf87d59..7f9705d33bd 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -3,7 +3,9 @@ package com.yahoo.tensor.serialization;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.text.JSON;
import org.junit.Test;
+import org.junit.jupiter.api.Assertions;
import java.nio.charset.StandardCharsets;
@@ -42,22 +44,6 @@ public class JsonFormatTestCase {
}
@Test
- public void testSparseTensor() {
- Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
- builder.cell().label("x", "a").label("y", "b").value(2.0);
- builder.cell().label("x", "c").label("y", "d").value(3.0);
- Tensor tensor = builder.build();
- byte[] json = JsonFormat.encode(tensor, false, false);
- assertEquals("{\"type\":\"tensor(x{},y{})\",\"cells\":[" +
- "{\"address\":{\"x\":\"a\",\"y\":\"b\"},\"value\":2.0}," +
- "{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":3.0}" +
- "]}",
- new String(json, StandardCharsets.UTF_8));
- Tensor decoded = JsonFormat.decode(tensor.type(), json);
- assertEquals(tensor, decoded);
- }
-
- @Test
public void testEmptySparseTensor() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
Tensor tensor = builder.build();
@@ -88,6 +74,45 @@ public class JsonFormatTestCase {
}
@Test
+ public void testEmptyTensor() {
+ Tensor tensor = Tensor.Builder.of(TensorType.empty).build();
+
+ String shortJson = """
+ {
+ "type":"tensor()",
+ "values":[0.0]
+ }
+ """;
+ byte[] shortEncoded = JsonFormat.encode(tensor, true, false);
+ assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded));
+
+ String longJson = """
+ {
+ "type":"tensor()",
+ "cells":[{"address":{},"value":0.0}]
+ }
+ """;
+ byte[] longEncoded = JsonFormat.encode(tensor, false, false);
+ assertEqualJson(longJson, new String(longEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longEncoded));
+
+ String shortDirectJson = """
+ [0.0]
+ """;
+ byte[] shortDirectEncoded = JsonFormat.encode(tensor, true, true);
+ assertEqualJson(shortDirectJson, new String(shortDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortDirectEncoded));
+
+ String longDirectJson = """
+ [{"address":{},"value":0.0}]
+ """;
+ byte[] longDirectEncoded = JsonFormat.encode(tensor, false, true);
+ assertEqualJson(longDirectJson, new String(longDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longDirectEncoded));
+ }
+
+ @Test
public void testDenseTensor() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[2])"));
builder.cell().label("x", 0).label("y", 0).value(2.0);
@@ -95,31 +120,183 @@ public class JsonFormatTestCase {
builder.cell().label("x", 1).label("y", 0).value(5.0);
builder.cell().label("x", 1).label("y", 1).value(7.0);
Tensor tensor = builder.build();
- byte[] json = JsonFormat.encode(tensor, false, false);
- assertEquals("{\"type\":\"tensor(x[2],y[2])\",\"cells\":[" +
- "{\"address\":{\"x\":\"0\",\"y\":\"0\"},\"value\":2.0}," +
- "{\"address\":{\"x\":\"0\",\"y\":\"1\"},\"value\":3.0}," +
- "{\"address\":{\"x\":\"1\",\"y\":\"0\"},\"value\":5.0}," +
- "{\"address\":{\"x\":\"1\",\"y\":\"1\"},\"value\":7.0}" +
- "]}",
- new String(json, StandardCharsets.UTF_8));
- Tensor decoded = JsonFormat.decode(tensor.type(), json);
- assertEquals(tensor, decoded);
+
+ String shortJson = """
+ {
+ "type":"tensor(x[2],y[2])",
+ "values":[[2.0,3.0],[5.0,7.0]]
+ }
+ """;
+ byte[] shortEncoded = JsonFormat.encode(tensor, true, false);
+ assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded));
+
+ String longJson = """
+ {
+ "type":"tensor(x[2],y[2])",
+ "cells":[
+ {"address":{"x":"0","y":"0"},"value":2.0},
+ {"address":{"x":"0","y":"1"},"value":3.0},
+ {"address":{"x":"1","y":"0"},"value":5.0},
+ {"address":{"x":"1","y":"1"},"value":7.0}
+ ]
+ }
+ """;
+ byte[] longEncoded = JsonFormat.encode(tensor, false, false);
+ assertEqualJson(longJson, new String(longEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longEncoded));
+
+ String shortDirectJson = """
+ [[2.0, 3.0], [5.0, 7.0]]
+ """;
+ byte[] shortDirectEncoded = JsonFormat.encode(tensor, true, true);
+ assertEqualJson(shortDirectJson, new String(shortDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortDirectEncoded));
+
+ String longDirectJson = """
+ [
+ {"address":{"x":"0","y":"0"},"value":2.0},
+ {"address":{"x":"0","y":"1"},"value":3.0},
+ {"address":{"x":"1","y":"0"},"value":5.0},
+ {"address":{"x":"1","y":"1"},"value":7.0}
+ ]
+ """;
+ byte[] longDirectEncoded = JsonFormat.encode(tensor, false, true);
+ assertEqualJson(longDirectJson, new String(longDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longDirectEncoded));
+ }
+
+ @Test
+ public void testMixedTensor() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[2])"));
+ builder.cell().label("x", "a").label("y", 0).value(2.0);
+ builder.cell().label("x", "a").label("y", 1).value(3.0);
+ builder.cell().label("x", "b").label("y", 0).value(5.0);
+ builder.cell().label("x", "b").label("y", 1).value(7.0);
+ Tensor tensor = builder.build();
+
+ String shortJson = """
+ {
+ "type":"tensor(x{},y[2])",
+ "blocks":{"a":[2.0,3.0],"b":[5.0,7.0]}
+ }
+ """;
+ byte[] shortEncoded = JsonFormat.encode(tensor, true, false);
+ assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded));
+
+ String longJson = """
+ {
+ "type":"tensor(x{},y[2])",
+ "cells":[
+ {"address":{"x":"a","y":"0"},"value":2.0},
+ {"address":{"x":"a","y":"1"},"value":3.0},
+ {"address":{"x":"b","y":"0"},"value":5.0},
+ {"address":{"x":"b","y":"1"},"value":7.0}
+ ]
+ }
+ """;
+ byte[] longEncoded = JsonFormat.encode(tensor, false, false);
+ assertEqualJson(longJson, new String(longEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longEncoded));
+
+ String shortDirectJson = """
+ {"a":[2.0,3.0],"b":[5.0,7.0]}
+ """;
+ byte[] shortDirectEncoded = JsonFormat.encode(tensor, true, true);
+ assertEqualJson(shortDirectJson, new String(shortDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortDirectEncoded));
+
+ String longDirectJson = """
+ [
+ {"address":{"x":"a","y":"0"},"value":2.0},
+ {"address":{"x":"a","y":"1"},"value":3.0},
+ {"address":{"x":"b","y":"0"},"value":5.0},
+ {"address":{"x":"b","y":"1"},"value":7.0}
+ ]
+ """;
+ byte[] longDirectEncoded = JsonFormat.encode(tensor, false, true);
+ assertEqualJson(longDirectJson, new String(longDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longDirectEncoded));
+ }
+
+ @Test
+ public void testSparseTensor() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
+ builder.cell().label("x", "a").label("y", 0).value(2.0);
+ builder.cell().label("x", "a").label("y", 1).value(3.0);
+ builder.cell().label("x", "b").label("y", 0).value(5.0);
+ builder.cell().label("x", "b").label("y", 1).value(7.0);
+ Tensor tensor = builder.build();
+
+ String shortJson = """
+ {
+ "type":"tensor(x{},y{})",
+ "cells": [
+ {"address":{"x":"a","y":"0"},"value":2.0},
+ {"address":{"x":"a","y":"1"},"value":3.0},
+ {"address":{"x":"b","y":"0"},"value":5.0},
+ {"address":{"x":"b","y":"1"},"value":7.0}
+ ]
+ }
+ """;
+ byte[] shortEncoded = JsonFormat.encode(tensor, true, false);
+ assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded));
+
+ String longJson = """
+ {
+ "type":"tensor(x{},y{})",
+ "cells":[
+ {"address":{"x":"a","y":"0"},"value":2.0},
+ {"address":{"x":"a","y":"1"},"value":3.0},
+ {"address":{"x":"b","y":"0"},"value":5.0},
+ {"address":{"x":"b","y":"1"},"value":7.0}
+ ]
+ }
+ """;
+ byte[] longEncoded = JsonFormat.encode(tensor, false, false);
+ assertEqualJson(longJson, new String(longEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longEncoded));
+
+ String shortDirectJson = """
+ [
+ {"address":{"x":"a","y":"0"},"value":2.0},
+ {"address":{"x":"a","y":"1"},"value":3.0},
+ {"address":{"x":"b","y":"0"},"value":5.0},
+ {"address":{"x":"b","y":"1"},"value":7.0}
+ ]
+ """;
+ byte[] shortDirectEncoded = JsonFormat.encode(tensor, true, true);
+ assertEqualJson(shortDirectJson, new String(shortDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortDirectEncoded));
+
+ String longDirectJson = """
+ [
+ {"address":{"x":"a","y":"0"},"value":2.0},
+ {"address":{"x":"a","y":"1"},"value":3.0},
+ {"address":{"x":"b","y":"0"},"value":5.0},
+ {"address":{"x":"b","y":"1"},"value":7.0}
+ ]
+ """;
+ byte[] longDirectEncoded = JsonFormat.encode(tensor, false, true);
+ assertEqualJson(longDirectJson, new String(longDirectEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), longDirectEncoded));
}
@Test
public void testDisallowedEmptyDenseTensor() {
TensorType type = TensorType.fromSpec("tensor(x[3])");
- assertDecodeFails(type, "{\"values\":[]}", "The 'values' array does not contain any values");
- assertDecodeFails(type, "{\"values\":\"\"}", "The 'values' string does not contain any values");
+ assertDecodeFails(type, "{\"values\":[]}", "The values array does not contain any values");
+ assertDecodeFails(type, "{\"values\":\"\"}", "The values string does not contain any values");
}
@Test
public void testDisallowedEmptyMixedTensor() {
TensorType type = TensorType.fromSpec("tensor(x{},y[3])");
- assertDecodeFails(type, "{\"blocks\":{ \"a\": [] } }", "The 'block' value array does not contain any values");
+ assertDecodeFails(type, "{\"blocks\":{ \"a\": [] } }", "The block value array does not contain any values");
assertDecodeFails(type, "{\"blocks\":[ {\"address\":{\"x\":\"a\"}, \"values\": [] } ] }",
- "The 'block' value array does not contain any values");
+ "The block value array does not contain any values");
}
@Test
@@ -426,8 +603,12 @@ public class JsonFormatTestCase {
Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8));
fail("Did not get exception as expected, decoded as: " + decoded);
} catch (IllegalArgumentException e) {
- assertEquals(e.getMessage(), msg);
+ assertEquals(msg, e.getMessage());
}
}
+ private void assertEqualJson(String expected, String generated) {
+ Assertions.assertEquals(JSON.canonical(expected), JSON.canonical(generated));
+ }
+
}