summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-06-18 12:39:21 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-06-18 12:39:21 +0200
commitdf72032ccaa6f2a132bba6a7c47b9885479b09de (patch)
tree5ccc1d3f7046fa4e24814534c7787c9e1b7fcde7 /vespajlib
parenta62b4b191fcdde36066e9ea362e4ea2dd4fd0114 (diff)
Deserialize dense form
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/lang/MutableInteger.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java38
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java17
3 files changed, 54 insertions, 7 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableInteger.java b/vespajlib/src/main/java/com/yahoo/lang/MutableInteger.java
index a988a3f6fa2..e2da62b6098 100644
--- a/vespajlib/src/main/java/com/yahoo/lang/MutableInteger.java
+++ b/vespajlib/src/main/java/com/yahoo/lang/MutableInteger.java
@@ -24,6 +24,12 @@ public class MutableInteger {
return value;
}
+ /** Increments the value by 1 and returns the value of this *before* incrementing */
+ public int next() {
+ value++;
+ return value - 1;
+ }
+
/** Adds the increment to the current value and returns the resulting value */
public int subtract(int increment) {
value -= increment;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 1a210a614cc..01da45c67aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
+import com.yahoo.lang.MutableInteger;
import com.yahoo.slime.ArrayTraverser;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Inspector;
@@ -8,6 +9,7 @@ import com.yahoo.slime.JsonDecoder;
import com.yahoo.slime.ObjectTraverser;
import com.yahoo.slime.Slime;
import com.yahoo.slime.Type;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -17,7 +19,9 @@ import java.util.Iterator;
/**
* Writes tensors on the JSON format used in Vespa tensor document fields:
* A JSON map containing a 'cells' array.
- * See http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor
+ * See a http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor
+ *
+ * @author bratseth
*/
public class JsonFormat {
@@ -55,13 +59,20 @@ public class JsonFormat {
/** Deserializes the given tensor from JSON format */
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
- Tensor.Builder tensorBuilder = Tensor.Builder.of(type);
+ Tensor.Builder builder = Tensor.Builder.of(type);
Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
- Inspector cells = root.field("cells");
+
+ if (root.field("cells").valid())
+ decodeCells(root.field("cells"), builder);
+ else if (root.field("values").valid())
+ decodeValues(root.field("values"), builder);
+ return builder.build();
+ }
+
+ private static void decodeCells(Inspector cells, Tensor.Builder builder) {
if ( cells.type() != Type.ARRAY)
- throw new IllegalArgumentException("Excepted an array item named 'cells' at the top level");
- cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell()));
- return tensorBuilder.build();
+ throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type());
+ cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder.cell()));
}
private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) {
@@ -76,4 +87,19 @@ public class JsonFormat {
cellBuilder.value(value.asDouble());
}
+ private static void decodeValues(Inspector values, Tensor.Builder builder) {
+ if ( ! (builder instanceof IndexedTensor.BoundBuilder))
+ throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
+ "Use 'cells' instead");
+ if ( values.type() != Type.ARRAY)
+ throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type());
+
+ IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
+ MutableInteger index = new MutableInteger(0);
+ values.traverse((ArrayTraverser) (__, value) -> {
+ if (value.type() != Type.LONG && value.type() != Type.DOUBLE)
+ throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type());
+ indexedBuilder.cellByDirectIndex(index.next(), value.asDouble());
+ });
+ }
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index b466307d3b9..4c44cbbf5c7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -33,7 +33,7 @@ public class JsonFormatTestCase {
@Test
public void testDenseTensor() {
- Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[2])"));
builder.cell().label("x", 0).label("y", 0).value(2.0);
builder.cell().label("x", 0).label("y", 1).value(3.0);
builder.cell().label("x", 1).label("y", 0).value(5.0);
@@ -52,6 +52,21 @@ public class JsonFormatTestCase {
}
@Test
+ public void testDenseTensorInDenseForm() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("x", 0).label("y", 2).value(4.0);
+ builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 1).value(6.0);
+ builder.cell().label("x", 1).label("y", 2).value(7.0);
+ Tensor expected = builder.build();
+ String denseJson = "{\"values\":[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]}";
+ Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8));
+ assertEquals(expected, decoded);
+ }
+
+ @Test
public void testTooManyCells() {
TensorType x2 = TensorType.fromSpec("tensor(x[2])");
String json = "{\"cells\":[" +