summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-06 15:47:07 +0100
committerGeir Storli <geirst@verizonmedia.com>2019-02-07 08:39:15 +0100
commitf0b7e22e881e26854473187498d409c5d36edfa2 (patch)
treeb905551ea2c3b993e4c10cee0e8882497e2adc8a
parent245a9611bce4d9d214ccb76016b67b6ca441dd24 (diff)
Support add update for sparse tensors in json reader.
A skeleton for TensorAddUpdate is added as well.
-rw-r--r--document/abi-spec.json21
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java43
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java8
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java68
-rw-r--r--document/src/main/java/com/yahoo/document/update/ValueUpdate.java3
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java51
6 files changed, 192 insertions, 2 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json
index a73383bab0d..7d7aad64bca 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -5181,6 +5181,26 @@
"protected com.yahoo.document.datatypes.FieldValue value"
]
},
+ "com.yahoo.document.update.TensorAddUpdate": {
+ "superClass": "com.yahoo.document.update.ValueUpdate",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.document.datatypes.TensorFieldValue)",
+ "protected void checkCompatibility(com.yahoo.document.DataType)",
+ "public void serialize(com.yahoo.document.serialization.DocumentUpdateWriter, com.yahoo.document.DataType)",
+ "public com.yahoo.document.datatypes.FieldValue applyTo(com.yahoo.document.datatypes.FieldValue)",
+ "public com.yahoo.document.datatypes.TensorFieldValue getValue()",
+ "public void setValue(com.yahoo.document.datatypes.TensorFieldValue)",
+ "public boolean equals(java.lang.Object)",
+ "public int hashCode()",
+ "public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)",
+ "public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()"
+ ],
+ "fields": []
+ },
"com.yahoo.document.update.TensorModifyUpdate$Operation": {
"superClass": "java.lang.Enum",
"interfaces": [],
@@ -5249,6 +5269,7 @@
"public static final enum com.yahoo.document.update.ValueUpdate$ValueUpdateClassID MAP",
"public static final enum com.yahoo.document.update.ValueUpdate$ValueUpdateClassID REMOVE",
"public static final enum com.yahoo.document.update.ValueUpdate$ValueUpdateClassID TENSORMODIFY",
+ "public static final enum com.yahoo.document.update.ValueUpdate$ValueUpdateClassID TENSORADD",
"public final int id",
"public final java.lang.String name"
]
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
new file mode 100644
index 00000000000..2cca11c19f8
--- /dev/null
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
@@ -0,0 +1,43 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.document.json.readers;
+
+import com.yahoo.document.Field;
+import com.yahoo.document.TensorDataType;
+import com.yahoo.document.datatypes.TensorFieldValue;
+import com.yahoo.document.json.TokenBuffer;
+import com.yahoo.document.update.TensorAddUpdate;
+import com.yahoo.tensor.TensorType;
+
+import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart;
+import static com.yahoo.document.json.readers.TensorReader.fillTensor;
+
+/**
+ * Class used to read an add update for a tensor field.
+ */
+public class TensorAddUpdateReader {
+
+ public static boolean isTensorField(Field field) {
+ return field.getDataType() instanceof TensorDataType;
+ }
+
+ public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) {
+ expectObjectStart(buffer.currentToken());
+ expectTensorTypeIsSparse(field);
+
+ TensorDataType tensorDataType = (TensorDataType)field.getDataType();
+ TensorType tensorType = tensorDataType.getTensorType();
+ TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType);
+ fillTensor(buffer, tensorFieldValue);
+ return new TensorAddUpdate(tensorFieldValue);
+ }
+
+ private static void expectTensorTypeIsSparse(Field field) {
+ TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
+ if (tensorType.dimensions().stream()
+ .anyMatch(dim -> dim.isIndexed())) {
+ throw new IllegalArgumentException("An add update can only be applied to sparse tensors. "
+ + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ }
+ }
+
+}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java b/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java
index cbdade6ad3c..c2b69062487 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java
@@ -30,6 +30,8 @@ import static com.yahoo.document.json.readers.MapReader.UPDATE_MATCH;
import static com.yahoo.document.json.readers.MapReader.createMapUpdate;
import static com.yahoo.document.json.readers.SingleValueReader.UPDATE_ASSIGN;
import static com.yahoo.document.json.readers.SingleValueReader.readSingleUpdate;
+import static com.yahoo.document.json.readers.TensorAddUpdateReader.createTensorAddUpdate;
+import static com.yahoo.document.json.readers.TensorAddUpdateReader.isTensorField;
import static com.yahoo.document.json.readers.TensorModifyUpdateReader.UPDATE_MODIFY;
import static com.yahoo.document.json.readers.TensorModifyUpdateReader.createModifyUpdate;
@@ -119,7 +121,11 @@ public class VespaJsonDocumentReader {
createRemoves(buffer, field, fieldUpdate);
break;
case UPDATE_ADD:
- createAdds(buffer, field, fieldUpdate);
+ if (isTensorField(field)) {
+ fieldUpdate.addValueUpdate(createTensorAddUpdate(buffer, field));
+ } else {
+ createAdds(buffer, field, fieldUpdate);
+ }
break;
case UPDATE_MATCH:
fieldUpdate.addValueUpdate(createMapUpdate(buffer, field));
diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
new file mode 100644
index 00000000000..3703ffc17a2
--- /dev/null
+++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
@@ -0,0 +1,68 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.document.update;
+
+import com.yahoo.document.DataType;
+import com.yahoo.document.TensorDataType;
+import com.yahoo.document.datatypes.FieldValue;
+import com.yahoo.document.datatypes.TensorFieldValue;
+import com.yahoo.document.serialization.DocumentUpdateWriter;
+
+import java.util.Objects;
+
+/**
+ * An update used to add cells to a sparse tensor (has only mapped dimensions).
+ *
+ * The cells to add are contained in a sparse tensor as well.
+ */
+public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
+
+ private TensorFieldValue tensor;
+
+ public TensorAddUpdate(TensorFieldValue tensor) {
+ super(ValueUpdateClassID.TENSORADD);
+ this.tensor = tensor;
+ }
+
+ @Override
+ protected void checkCompatibility(DataType fieldType) {
+ if (!(fieldType instanceof TensorDataType)) {
+ throw new UnsupportedOperationException("Expected tensor type, got " + fieldType.getName() + ".");
+ }
+ }
+
+ @Override
+ public void serialize(DocumentUpdateWriter data, DataType superType) {
+ // TODO: implement
+ throw new UnsupportedOperationException("Not implemented yet");
+ }
+
+ @Override
+ public FieldValue applyTo(FieldValue oldValue) {
+ // TODO: implement
+ return null;
+ }
+
+ @Override
+ public TensorFieldValue getValue() {
+ return tensor;
+ }
+
+ @Override
+ public void setValue(TensorFieldValue value) {
+ tensor = value;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ if (!super.equals(o)) return false;
+ TensorAddUpdate that = (TensorAddUpdate) o;
+ return tensor.equals(that.tensor);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), tensor);
+ }
+}
diff --git a/document/src/main/java/com/yahoo/document/update/ValueUpdate.java b/document/src/main/java/com/yahoo/document/update/ValueUpdate.java
index 9600b820549..4e4dae60589 100644
--- a/document/src/main/java/com/yahoo/document/update/ValueUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/ValueUpdate.java
@@ -343,7 +343,8 @@ public abstract class ValueUpdate<T extends FieldValue> {
CLEAR(28, "clear"),
MAP(29, "map"),
REMOVE(30, "remove"),
- TENSORMODIFY(100, "tensormodify");
+ TENSORMODIFY(100, "tensormodify"),
+ TENSORADD(101, "tensoradd");
public final int id;
public final String name;
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 2588d56e24f..ecbd615a633 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -43,6 +43,7 @@ import com.yahoo.document.update.AssignValueUpdate;
import com.yahoo.document.update.ClearValueUpdate;
import com.yahoo.document.update.FieldUpdate;
import com.yahoo.document.update.MapValueUpdate;
+import com.yahoo.document.update.TensorAddUpdate;
import com.yahoo.document.update.TensorModifyUpdate;
import com.yahoo.document.update.ValueUpdate;
import com.yahoo.io.GrowableByteBuffer;
@@ -1428,6 +1429,32 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_add_update_on_sparse_tensor() {
+ assertTensorAddUpdate("{{x:a,y:b}:2.0, {x:c,y:d}: 3.0}", "sparse_tensor",
+ inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 },",
+ " { 'address': { 'x': 'c', 'y': 'd' }, 'value': 3.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_add_update_on_non_sparse_tensor_throws() {
+ exception.expect(IllegalArgumentException.class);
+ exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'");
+ createTensorAddUpdate(inputJson("{",
+ " 'cells': [] }"), "mixed_tensor");
+ }
+
+ @Test
+ public void tensor_add_update_on_not_fully_specified_cell_throws() {
+ exception.expect(IllegalArgumentException.class);
+ exception.expectMessage("Missing a value for dimension y for tensor(x{},y{})");
+ createTensorAddUpdate(inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': 'a' }, 'value': 2.0 } ]}"), "sparse_tensor");
+ }
+
+ @Test
public void require_that_parser_propagates_datatype_parser_errors_predicate() {
assertParserErrorMatches(
"Error in document 'id:unittest:testpredicate::0' - could not parse field 'boolean' of type 'predicate': " +
@@ -1573,6 +1600,30 @@ public class JsonReaderTestCase {
assertEquals(Tensor.from(expectedTensor), modifyUpdate.getValue().getTensor().get());
}
+ private DocumentUpdate createTensorAddUpdate(String tensorJson, String tensorFieldName) {
+ JsonReader reader = createReader(inputJson("[",
+ "{ 'update': '" + TENSOR_DOC_ID + "',",
+ " 'fields': {",
+ " '" + tensorFieldName + "': {",
+ " 'add': " + tensorJson + " }}}]"));
+ return (DocumentUpdate) reader.next();
+ }
+
+ private void assertTensorAddUpdate(String expectedTensor, String tensorFieldName, String tensorJson) {
+ assertTensorAddUpdate(expectedTensor, tensorFieldName,
+ createTensorAddUpdate(tensorJson, tensorFieldName));
+ }
+
+ private static void assertTensorAddUpdate(String expectedTensor, String tensorFieldName, DocumentUpdate update) {
+ assertEquals("testtensor", update.getId().getDocType());
+ assertEquals(TENSOR_DOC_ID, update.getId().toString());
+ assertEquals(1, update.fieldUpdates().size());
+ FieldUpdate fieldUpdate = update.getFieldUpdate(tensorFieldName);
+ assertEquals(1, fieldUpdate.size());
+ TensorAddUpdate addUpdate = (TensorAddUpdate) fieldUpdate.getValueUpdate(0);
+ assertEquals(Tensor.from(expectedTensor), addUpdate.getValue().getTensor().get());
+ }
+
private static FieldUpdate getTensorField(DocumentUpdate update) {
FieldUpdate fieldUpdate = update.getFieldUpdate("sparse_tensor");
assertEquals(1, fieldUpdate.size());