diff options
Diffstat (limited to 'vespajlib')
3 files changed, 54 insertions, 7 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableInteger.java b/vespajlib/src/main/java/com/yahoo/lang/MutableInteger.java index a988a3f6fa2..e2da62b6098 100644 --- a/vespajlib/src/main/java/com/yahoo/lang/MutableInteger.java +++ b/vespajlib/src/main/java/com/yahoo/lang/MutableInteger.java @@ -24,6 +24,12 @@ public class MutableInteger { return value; } + /** Increments the value by 1 and returns the value of this *before* incrementing */ + public int next() { + value++; + return value - 1; + } + /** Adds the increment to the current value and returns the resulting value */ public int subtract(int increment) { value -= increment; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 1a210a614cc..01da45c67aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; +import com.yahoo.lang.MutableInteger; import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; @@ -8,6 +9,7 @@ import com.yahoo.slime.JsonDecoder; import com.yahoo.slime.ObjectTraverser; import com.yahoo.slime.Slime; import com.yahoo.slime.Type; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -17,7 +19,9 @@ import java.util.Iterator; /** * Writes tensors on the JSON format used in Vespa tensor document fields: * A JSON map containing a 'cells' array. - * See http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor + * See a http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor + * + * @author bratseth */ public class JsonFormat { @@ -55,13 +59,20 @@ public class JsonFormat { /** Deserializes the given tensor from JSON format */ public static Tensor decode(TensorType type, byte[] jsonTensorValue) { - Tensor.Builder tensorBuilder = Tensor.Builder.of(type); + Tensor.Builder builder = Tensor.Builder.of(type); Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get(); - Inspector cells = root.field("cells"); + + if (root.field("cells").valid()) + decodeCells(root.field("cells"), builder); + else if (root.field("values").valid()) + decodeValues(root.field("values"), builder); + return builder.build(); + } + + private static void decodeCells(Inspector cells, Tensor.Builder builder) { if ( cells.type() != Type.ARRAY) - throw new IllegalArgumentException("Excepted an array item named 'cells' at the top level"); - cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell())); - return tensorBuilder.build(); + throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type()); + cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder.cell())); } private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) { @@ -76,4 +87,19 @@ public class JsonFormat { cellBuilder.value(value.asDouble()); } + private static void decodeValues(Inspector values, Tensor.Builder builder) { + if ( ! (builder instanceof IndexedTensor.BoundBuilder)) + throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + + "Use 'cells' instead"); + if ( values.type() != Type.ARRAY) + throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type()); + + IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; + MutableInteger index = new MutableInteger(0); + values.traverse((ArrayTraverser) (__, value) -> { + if (value.type() != Type.LONG && value.type() != Type.DOUBLE) + throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type()); + indexedBuilder.cellByDirectIndex(index.next(), value.asDouble()); + }); + } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index b466307d3b9..4c44cbbf5c7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -33,7 +33,7 @@ public class JsonFormatTestCase { @Test public void testDenseTensor() { - Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[2])")); builder.cell().label("x", 0).label("y", 0).value(2.0); builder.cell().label("x", 0).label("y", 1).value(3.0); builder.cell().label("x", 1).label("y", 0).value(5.0); @@ -52,6 +52,21 @@ public class JsonFormatTestCase { } @Test + public void testDenseTensorInDenseForm() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 0).label("y", 2).value(4.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(6.0); + builder.cell().label("x", 1).label("y", 2).value(7.0); + Tensor expected = builder.build(); + String denseJson = "{\"values\":[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]}"; + Tensor decoded = JsonFormat.decode(expected.type(), denseJson.getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, decoded); + } + + @Test public void testTooManyCells() { TensorType x2 = TensorType.fromSpec("tensor(x[2])"); String json = "{\"cells\":[" + |