diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-18 13:46:08 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-18 13:46:08 +0200 |
commit | f6d3249d38f07a4526b2343ada38564d762d50d7 (patch) | |
tree | 3d6f9d5d420a84d7b76b7241b624552c968501a7 /document | |
parent | df72032ccaa6f2a132bba6a7c47b9885479b09de (diff) | |
Read dense tensor form in documents
Diffstat (limited to 'document')
-rw-r--r-- | document/src/main/java/com/yahoo/document/json/readers/TensorReader.java | 59 |
-rw-r--r-- | document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java | 27 |
2 files changed, 72 insertions, 14 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index a3d2a157073..6bdac611fdc 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -3,6 +3,10 @@ package com.yahoo.document.json.readers; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; +import com.yahoo.lang.MutableInteger; +import com.yahoo.slime.ArrayTraverser; +import com.yahoo.slime.Type; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; @@ -11,54 +15,61 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.*; /** * Reads the tensor format described at * http://docs.vespa.ai/documentation/reference/document-json-format.html#tensor + * + * @author geirst + * @author bratseth */ public class TensorReader { public static final String TENSOR_ADDRESS = "address"; public static final String TENSOR_DIMENSIONS = "dimensions"; public static final String TENSOR_CELLS = "cells"; + public static final String TENSOR_VALUES = "values"; public static final String TENSOR_VALUE = "value"; - public static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) { + static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) { // TODO: Switch implementation to om.yahoo.tensor.serialization.JsonFormat.decode - Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType()); + Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType()); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); - // read tensor cell fields and ignore everything else for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { - if 
(TensorReader.TENSOR_CELLS.equals(buffer.currentName())) - readTensorCells(buffer, tensorBuilder); + if (TENSOR_CELLS.equals(buffer.currentName())) + readTensorCells(buffer, builder); + else if (TENSOR_VALUES.equals(buffer.currentName())) + readTensorValues(buffer, builder); + else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty + throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'"); } expectObjectEnd(buffer.currentToken()); - tensorFieldValue.assign(tensorBuilder.build()); + tensorFieldValue.assign(builder.build()); } - public static void readTensorCells(TokenBuffer buffer, Tensor.Builder tensorBuilder) { + static void readTensorCells(TokenBuffer buffer, Tensor.Builder builder) { expectArrayStart(buffer.currentToken()); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) - readTensorCell(buffer, tensorBuilder); + readTensorCell(buffer, builder); expectCompositeEnd(buffer.currentToken()); } - public static void readTensorCell(TokenBuffer buffer, Tensor.Builder tensorBuilder) { + private static void readTensorCell(TokenBuffer buffer, Tensor.Builder builder) { expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); double cellValue = 0.0; - Tensor.Builder.CellBuilder cellBuilder = tensorBuilder.cell(); + Tensor.Builder.CellBuilder cellBuilder = builder.cell(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { String currentName = buffer.currentName(); if (TensorReader.TENSOR_ADDRESS.equals(currentName)) { readTensorAddress(buffer, cellBuilder); } else if (TensorReader.TENSOR_VALUE.equals(currentName)) { - cellValue = Double.valueOf(buffer.currentText()); + cellValue = readDouble(buffer); } } expectObjectEnd(buffer.currentToken()); cellBuilder.value(cellValue); } - public static void readTensorAddress(TokenBuffer buffer, MappedTensor.Builder.CellBuilder cellBuilder) { + private static 
void readTensorAddress(TokenBuffer buffer, MappedTensor.Builder.CellBuilder cellBuilder) { expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { @@ -68,4 +79,28 @@ public class TensorReader { } expectObjectEnd(buffer.currentToken()); } + + private static void readTensorValues(TokenBuffer buffer, Tensor.Builder builder) { + if ( ! (builder instanceof IndexedTensor.BoundBuilder)) + throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + + "Use 'cells' instead"); + expectArrayStart(buffer.currentToken()); + + IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; + int index = 0; + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + indexedBuilder.cellByDirectIndex(index++, readDouble(buffer)); + expectCompositeEnd(buffer.currentToken()); + } + + private static double readDouble(TokenBuffer buffer) { + try { + return Double.valueOf(buffer.currentText()); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Expected a number but got '" + buffer.currentText()); + } + } + } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index f8ee23e86ba..69be397595e 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -52,6 +52,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.text.Utf8; import org.junit.After; import org.junit.Before; @@ -63,6 +64,7 @@ import org.mockito.internal.matchers.Contains; import java.io.ByteArrayInputStream; import java.io.IOException; 
import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Base64; import java.util.Collections; @@ -1294,6 +1296,24 @@ public class JsonReaderTestCase { } @Test + public void testParsingOfDenseTensorOnDenseForm() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 0).label("y", 2).value(4.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(6.0); + builder.cell().label("x", 1).label("y", 2).value(7.0); + Tensor expected = builder.build(); + + Tensor tensor = assertTensorField(expected, + createPutWithTensor(inputJson("{", + " 'values': [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]", + "}"), "dense_tensor"), "dense_tensor"); + assertTrue(tensor instanceof IndexedTensor); // this matters for performance + } + + @Test public void testParsingOfTensorWithSingleCellInDifferentJsonOrder() { assertSparseTensorField("{{x:a,y:b}:2.0}", createPutWithSparseTensor(inputJson("{", @@ -1689,11 +1709,14 @@ public class JsonReaderTestCase { return assertTensorField(expectedTensor, put, "sparse_tensor"); } private static Tensor assertTensorField(String expectedTensor, DocumentPut put, String tensorFieldName) { - final Document doc = put.getDocument(); + return assertTensorField(Tensor.from(expectedTensor), put, tensorFieldName); + } + private static Tensor assertTensorField(Tensor expectedTensor, DocumentPut put, String tensorFieldName) { + Document doc = put.getDocument(); assertEquals("testtensor", doc.getId().getDocType()); assertEquals(TENSOR_DOC_ID, doc.getId().toString()); TensorFieldValue fieldValue = (TensorFieldValue)doc.getFieldValue(doc.getField(tensorFieldName)); - assertEquals(Tensor.from(expectedTensor), fieldValue.getTensor().get()); + assertEquals(expectedTensor, fieldValue.getTensor().get()); 
return fieldValue.getTensor().get(); } |