diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 38 |
1 files changed, 32 insertions, 6 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 1a210a614cc..01da45c67aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; +import com.yahoo.lang.MutableInteger; import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; @@ -8,6 +9,7 @@ import com.yahoo.slime.JsonDecoder; import com.yahoo.slime.ObjectTraverser; import com.yahoo.slime.Slime; import com.yahoo.slime.Type; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -17,7 +19,9 @@ import java.util.Iterator; /** * Writes tensors on the JSON format used in Vespa tensor document fields: * A JSON map containing a 'cells' array. - * See http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor + * See a http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor + * + * @author bratseth */ public class JsonFormat { @@ -55,13 +59,20 @@ public class JsonFormat { /** Deserializes the given tensor from JSON format */ public static Tensor decode(TensorType type, byte[] jsonTensorValue) { - Tensor.Builder tensorBuilder = Tensor.Builder.of(type); + Tensor.Builder builder = Tensor.Builder.of(type); Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get(); - Inspector cells = root.field("cells"); + + if (root.field("cells").valid()) + decodeCells(root.field("cells"), builder); + else if (root.field("values").valid()) + decodeValues(root.field("values"), builder); + return builder.build(); + } + + private static void decodeCells(Inspector cells, Tensor.Builder builder) { if ( cells.type() != Type.ARRAY) - throw new IllegalArgumentException("Excepted an array item named 'cells' at the top level"); - cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell())); - return tensorBuilder.build(); + throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type()); + cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder.cell())); } private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) { @@ -76,4 +87,19 @@ public class JsonFormat { cellBuilder.value(value.asDouble()); } + private static void decodeValues(Inspector values, Tensor.Builder builder) { + if ( ! (builder instanceof IndexedTensor.BoundBuilder)) + throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + + "Use 'cells' instead"); + if ( values.type() != Type.ARRAY) + throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type()); + + IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; + MutableInteger index = new MutableInteger(0); + values.traverse((ArrayTraverser) (__, value) -> { + if (value.type() != Type.LONG && value.type() != Type.DOUBLE) + throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type()); + indexedBuilder.cellByDirectIndex(index.next(), value.asDouble()); + }); + } } |