summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-06-18 13:46:08 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-06-18 13:46:08 +0200
commitf6d3249d38f07a4526b2343ada38564d762d50d7 (patch)
tree3d6f9d5d420a84d7b76b7241b624552c968501a7 /document
parentdf72032ccaa6f2a132bba6a7c47b9885479b09de (diff)
Read dense tensor form in documents
Diffstat (limited to 'document')
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java59
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java27
2 files changed, 72 insertions, 14 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index a3d2a157073..6bdac611fdc 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -3,6 +3,10 @@ package com.yahoo.document.json.readers;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
+import com.yahoo.lang.MutableInteger;
+import com.yahoo.slime.ArrayTraverser;
+import com.yahoo.slime.Type;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
@@ -11,54 +15,61 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.*;
/**
* Reads the tensor format described at
* http://docs.vespa.ai/documentation/reference/document-json-format.html#tensor
+ *
+ * @author geirst
+ * @author bratseth
*/
public class TensorReader {
public static final String TENSOR_ADDRESS = "address";
public static final String TENSOR_DIMENSIONS = "dimensions";
public static final String TENSOR_CELLS = "cells";
+ public static final String TENSOR_VALUES = "values";
public static final String TENSOR_VALUE = "value";
- public static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) {
+ static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) {
// TODO: Switch implementation to om.yahoo.tensor.serialization.JsonFormat.decode
- Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType());
+ Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType());
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
- // read tensor cell fields and ignore everything else
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
- if (TensorReader.TENSOR_CELLS.equals(buffer.currentName()))
- readTensorCells(buffer, tensorBuilder);
+ if (TENSOR_CELLS.equals(buffer.currentName()))
+ readTensorCells(buffer, builder);
+ else if (TENSOR_VALUES.equals(buffer.currentName()))
+ readTensorValues(buffer, builder);
+ else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
+ throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'");
}
expectObjectEnd(buffer.currentToken());
- tensorFieldValue.assign(tensorBuilder.build());
+ tensorFieldValue.assign(builder.build());
}
- public static void readTensorCells(TokenBuffer buffer, Tensor.Builder tensorBuilder) {
+ static void readTensorCells(TokenBuffer buffer, Tensor.Builder builder) {
expectArrayStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
- readTensorCell(buffer, tensorBuilder);
+ readTensorCell(buffer, builder);
expectCompositeEnd(buffer.currentToken());
}
- public static void readTensorCell(TokenBuffer buffer, Tensor.Builder tensorBuilder) {
+ private static void readTensorCell(TokenBuffer buffer, Tensor.Builder builder) {
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
double cellValue = 0.0;
- Tensor.Builder.CellBuilder cellBuilder = tensorBuilder.cell();
+ Tensor.Builder.CellBuilder cellBuilder = builder.cell();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
String currentName = buffer.currentName();
if (TensorReader.TENSOR_ADDRESS.equals(currentName)) {
readTensorAddress(buffer, cellBuilder);
} else if (TensorReader.TENSOR_VALUE.equals(currentName)) {
- cellValue = Double.valueOf(buffer.currentText());
+ cellValue = readDouble(buffer);
}
}
expectObjectEnd(buffer.currentToken());
cellBuilder.value(cellValue);
}
- public static void readTensorAddress(TokenBuffer buffer, MappedTensor.Builder.CellBuilder cellBuilder) {
+ private static void readTensorAddress(TokenBuffer buffer, MappedTensor.Builder.CellBuilder cellBuilder) {
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
@@ -68,4 +79,28 @@ public class TensorReader {
}
expectObjectEnd(buffer.currentToken());
}
+
+ private static void readTensorValues(TokenBuffer buffer, Tensor.Builder builder) {
+ if ( ! (builder instanceof IndexedTensor.BoundBuilder))
+ throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
+ "Use 'cells' instead");
+ expectArrayStart(buffer.currentToken());
+
+ IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
+ int index = 0;
+ int initNesting = buffer.nesting();
+ for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
+ indexedBuilder.cellByDirectIndex(index++, readDouble(buffer));
+ expectCompositeEnd(buffer.currentToken());
+ }
+
+ private static double readDouble(TokenBuffer buffer) {
+ try {
+ return Double.valueOf(buffer.currentText());
+ }
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Expected a number but got '" + buffer.currentText());
+ }
+ }
+
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index f8ee23e86ba..69be397595e 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -52,6 +52,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.text.Utf8;
import org.junit.After;
import org.junit.Before;
@@ -63,6 +64,7 @@ import org.mockito.internal.matchers.Contains;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
@@ -1294,6 +1296,24 @@ public class JsonReaderTestCase {
}
@Test
+ public void testParsingOfDenseTensorOnDenseForm() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("x", 0).label("y", 2).value(4.0);
+ builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 1).value(6.0);
+ builder.cell().label("x", 1).label("y", 2).value(7.0);
+ Tensor expected = builder.build();
+
+ Tensor tensor = assertTensorField(expected,
+ createPutWithTensor(inputJson("{",
+ " 'values': [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]",
+ "}"), "dense_tensor"), "dense_tensor");
+ assertTrue(tensor instanceof IndexedTensor); // this matters for performance
+ }
+
+ @Test
public void testParsingOfTensorWithSingleCellInDifferentJsonOrder() {
assertSparseTensorField("{{x:a,y:b}:2.0}",
createPutWithSparseTensor(inputJson("{",
@@ -1689,11 +1709,14 @@ public class JsonReaderTestCase {
return assertTensorField(expectedTensor, put, "sparse_tensor");
}
private static Tensor assertTensorField(String expectedTensor, DocumentPut put, String tensorFieldName) {
- final Document doc = put.getDocument();
+ return assertTensorField(Tensor.from(expectedTensor), put, tensorFieldName);
+ }
+ private static Tensor assertTensorField(Tensor expectedTensor, DocumentPut put, String tensorFieldName) {
+ Document doc = put.getDocument();
assertEquals("testtensor", doc.getId().getDocType());
assertEquals(TENSOR_DOC_ID, doc.getId().toString());
TensorFieldValue fieldValue = (TensorFieldValue)doc.getFieldValue(doc.getField(tensorFieldName));
- assertEquals(Tensor.from(expectedTensor), fieldValue.getTensor().get());
+ assertEquals(expectedTensor, fieldValue.getTensor().get());
return fieldValue.getTensor().get();
}