summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java59
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java4
4 files changed, 78 insertions, 14 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index a3d2a157073..6bdac611fdc 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -3,6 +3,10 @@ package com.yahoo.document.json.readers;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
+import com.yahoo.lang.MutableInteger;
+import com.yahoo.slime.ArrayTraverser;
+import com.yahoo.slime.Type;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
@@ -11,54 +15,61 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.*;
/**
* Reads the tensor format described at
* http://docs.vespa.ai/documentation/reference/document-json-format.html#tensor
+ *
+ * @author geirst
+ * @author bratseth
*/
public class TensorReader {
public static final String TENSOR_ADDRESS = "address";
public static final String TENSOR_DIMENSIONS = "dimensions";
public static final String TENSOR_CELLS = "cells";
+ public static final String TENSOR_VALUES = "values";
public static final String TENSOR_VALUE = "value";
- public static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) {
+ static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) {
// TODO: Switch implementation to om.yahoo.tensor.serialization.JsonFormat.decode
- Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType());
+ Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType());
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
- // read tensor cell fields and ignore everything else
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
- if (TensorReader.TENSOR_CELLS.equals(buffer.currentName()))
- readTensorCells(buffer, tensorBuilder);
+ if (TENSOR_CELLS.equals(buffer.currentName()))
+ readTensorCells(buffer, builder);
+ else if (TENSOR_VALUES.equals(buffer.currentName()))
+ readTensorValues(buffer, builder);
+ else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
+ throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'");
}
expectObjectEnd(buffer.currentToken());
- tensorFieldValue.assign(tensorBuilder.build());
+ tensorFieldValue.assign(builder.build());
}
- public static void readTensorCells(TokenBuffer buffer, Tensor.Builder tensorBuilder) {
+ static void readTensorCells(TokenBuffer buffer, Tensor.Builder builder) {
expectArrayStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
- readTensorCell(buffer, tensorBuilder);
+ readTensorCell(buffer, builder);
expectCompositeEnd(buffer.currentToken());
}
- public static void readTensorCell(TokenBuffer buffer, Tensor.Builder tensorBuilder) {
+ private static void readTensorCell(TokenBuffer buffer, Tensor.Builder builder) {
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
double cellValue = 0.0;
- Tensor.Builder.CellBuilder cellBuilder = tensorBuilder.cell();
+ Tensor.Builder.CellBuilder cellBuilder = builder.cell();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
String currentName = buffer.currentName();
if (TensorReader.TENSOR_ADDRESS.equals(currentName)) {
readTensorAddress(buffer, cellBuilder);
} else if (TensorReader.TENSOR_VALUE.equals(currentName)) {
- cellValue = Double.valueOf(buffer.currentText());
+ cellValue = readDouble(buffer);
}
}
expectObjectEnd(buffer.currentToken());
cellBuilder.value(cellValue);
}
- public static void readTensorAddress(TokenBuffer buffer, MappedTensor.Builder.CellBuilder cellBuilder) {
+ private static void readTensorAddress(TokenBuffer buffer, MappedTensor.Builder.CellBuilder cellBuilder) {
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
@@ -68,4 +79,28 @@ public class TensorReader {
}
expectObjectEnd(buffer.currentToken());
}
+
+ private static void readTensorValues(TokenBuffer buffer, Tensor.Builder builder) {
+ if ( ! (builder instanceof IndexedTensor.BoundBuilder))
+ throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
+ "Use 'cells' instead");
+ expectArrayStart(buffer.currentToken());
+
+ IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
+ int index = 0;
+ int initNesting = buffer.nesting();
+ for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
+ indexedBuilder.cellByDirectIndex(index++, readDouble(buffer));
+ expectCompositeEnd(buffer.currentToken());
+ }
+
+ private static double readDouble(TokenBuffer buffer) {
+ try {
+ return Double.valueOf(buffer.currentText());
+ }
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Expected a number but got '" + buffer.currentText());
+ }
+ }
+
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index f8ee23e86ba..69be397595e 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -52,6 +52,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.text.Utf8;
import org.junit.After;
import org.junit.Before;
@@ -63,6 +64,7 @@ import org.mockito.internal.matchers.Contains;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
@@ -1294,6 +1296,24 @@ public class JsonReaderTestCase {
}
@Test
+ public void testParsingOfDenseTensorOnDenseForm() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("x", 0).label("y", 2).value(4.0);
+ builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 1).value(6.0);
+ builder.cell().label("x", 1).label("y", 2).value(7.0);
+ Tensor expected = builder.build();
+
+ Tensor tensor = assertTensorField(expected,
+ createPutWithTensor(inputJson("{",
+ " 'values': [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]",
+ "}"), "dense_tensor"), "dense_tensor");
+ assertTrue(tensor instanceof IndexedTensor); // this matters for performance
+ }
+
+ @Test
public void testParsingOfTensorWithSingleCellInDifferentJsonOrder() {
assertSparseTensorField("{{x:a,y:b}:2.0}",
createPutWithSparseTensor(inputJson("{",
@@ -1689,11 +1709,14 @@ public class JsonReaderTestCase {
return assertTensorField(expectedTensor, put, "sparse_tensor");
}
private static Tensor assertTensorField(String expectedTensor, DocumentPut put, String tensorFieldName) {
- final Document doc = put.getDocument();
+ return assertTensorField(Tensor.from(expectedTensor), put, tensorFieldName);
+ }
+ private static Tensor assertTensorField(Tensor expectedTensor, DocumentPut put, String tensorFieldName) {
+ Document doc = put.getDocument();
assertEquals("testtensor", doc.getId().getDocType());
assertEquals(TENSOR_DOC_ID, doc.getId().toString());
TensorFieldValue fieldValue = (TensorFieldValue)doc.getFieldValue(doc.getField(tensorFieldName));
- assertEquals(Tensor.from(expectedTensor), fieldValue.getTensor().get());
+ assertEquals(expectedTensor, fieldValue.getTensor().get());
return fieldValue.getTensor().get();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index aca2bfc1b0f..bc351b45b28 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -382,8 +382,10 @@ public abstract class IndexedTensor implements Tensor {
DimensionSizes sizes() { return sizes; }
+ /** Sets a value by its right-adjacent traversal position */
public abstract void cellByDirectIndex(long index, double value);
+ /** Sets a value by its right-adjacent traversal position */
public abstract void cellByDirectIndex(long index, float value);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 01da45c67aa..c73ff03a0eb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -58,6 +58,7 @@ public class JsonFormat {
}
/** Deserializes the given tensor from JSON format */
+ // NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
Tensor.Builder builder = Tensor.Builder.of(type);
Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
@@ -66,6 +67,8 @@ public class JsonFormat {
decodeCells(root.field("cells"), builder);
else if (root.field("values").valid())
decodeValues(root.field("values"), builder);
+ else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
+ throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'");
return builder.build();
}
@@ -102,4 +105,5 @@ public class JsonFormat {
indexedBuilder.cellByDirectIndex(index.next(), value.asDouble());
});
}
+
}