summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java38
1 files changed, 32 insertions, 6 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 1a210a614cc..01da45c67aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
+import com.yahoo.lang.MutableInteger;
import com.yahoo.slime.ArrayTraverser;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Inspector;
@@ -8,6 +9,7 @@ import com.yahoo.slime.JsonDecoder;
import com.yahoo.slime.ObjectTraverser;
import com.yahoo.slime.Slime;
import com.yahoo.slime.Type;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -17,7 +19,9 @@ import java.util.Iterator;
/**
* Writes tensors on the JSON format used in Vespa tensor document fields:
* A JSON map containing a 'cells' array.
- * See http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor
+ * See a http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor
+ *
+ * @author bratseth
*/
public class JsonFormat {
@@ -55,13 +59,20 @@ public class JsonFormat {
/** Deserializes the given tensor from JSON format */
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
- Tensor.Builder tensorBuilder = Tensor.Builder.of(type);
+ Tensor.Builder builder = Tensor.Builder.of(type);
Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
- Inspector cells = root.field("cells");
+
+ if (root.field("cells").valid())
+ decodeCells(root.field("cells"), builder);
+ else if (root.field("values").valid())
+ decodeValues(root.field("values"), builder);
+ return builder.build();
+ }
+
+ private static void decodeCells(Inspector cells, Tensor.Builder builder) {
if ( cells.type() != Type.ARRAY)
- throw new IllegalArgumentException("Excepted an array item named 'cells' at the top level");
- cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell()));
- return tensorBuilder.build();
+ throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type());
+ cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder.cell()));
}
private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) {
@@ -76,4 +87,19 @@ public class JsonFormat {
cellBuilder.value(value.asDouble());
}
+ private static void decodeValues(Inspector values, Tensor.Builder builder) {
+ if ( ! (builder instanceof IndexedTensor.BoundBuilder))
+ throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
+ "Use 'cells' instead");
+ if ( values.type() != Type.ARRAY)
+ throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type());
+
+ IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
+ MutableInteger index = new MutableInteger(0);
+ values.traverse((ArrayTraverser) (__, value) -> {
+ if (value.type() != Type.LONG && value.type() != Type.DOUBLE)
+ throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type());
+ indexedBuilder.cellByDirectIndex(index.next(), value.asDouble());
+ });
+ }
}