diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-02-06 15:47:07 +0100 |
---|---|---|
committer | Geir Storli <geirst@verizonmedia.com> | 2019-02-07 08:39:15 +0100 |
commit | f0b7e22e881e26854473187498d409c5d36edfa2 (patch) | |
tree | b905551ea2c3b993e4c10cee0e8882497e2adc8a | |
parent | 245a9611bce4d9d214ccb76016b67b6ca441dd24 (diff) |
Support add update for sparse tensors in json reader.
A skeleton for TensorAddUpdate is added as well.
6 files changed, 192 insertions, 2 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json index a73383bab0d..7d7aad64bca 100644 --- a/document/abi-spec.json +++ b/document/abi-spec.json @@ -5181,6 +5181,26 @@ "protected com.yahoo.document.datatypes.FieldValue value" ] }, + "com.yahoo.document.update.TensorAddUpdate": { + "superClass": "com.yahoo.document.update.ValueUpdate", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(com.yahoo.document.datatypes.TensorFieldValue)", + "protected void checkCompatibility(com.yahoo.document.DataType)", + "public void serialize(com.yahoo.document.serialization.DocumentUpdateWriter, com.yahoo.document.DataType)", + "public com.yahoo.document.datatypes.FieldValue applyTo(com.yahoo.document.datatypes.FieldValue)", + "public com.yahoo.document.datatypes.TensorFieldValue getValue()", + "public void setValue(com.yahoo.document.datatypes.TensorFieldValue)", + "public boolean equals(java.lang.Object)", + "public int hashCode()", + "public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)", + "public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()" + ], + "fields": [] + }, "com.yahoo.document.update.TensorModifyUpdate$Operation": { "superClass": "java.lang.Enum", "interfaces": [], @@ -5249,6 +5269,7 @@ "public static final enum com.yahoo.document.update.ValueUpdate$ValueUpdateClassID MAP", "public static final enum com.yahoo.document.update.ValueUpdate$ValueUpdateClassID REMOVE", "public static final enum com.yahoo.document.update.ValueUpdate$ValueUpdateClassID TENSORMODIFY", + "public static final enum com.yahoo.document.update.ValueUpdate$ValueUpdateClassID TENSORADD", "public final int id", "public final java.lang.String name" ] diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java new file mode 100644 index 00000000000..2cca11c19f8 --- /dev/null +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java @@ -0,0 +1,43 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.document.json.readers; + +import com.yahoo.document.Field; +import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.document.json.TokenBuffer; +import com.yahoo.document.update.TensorAddUpdate; +import com.yahoo.tensor.TensorType; + +import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart; +import static com.yahoo.document.json.readers.TensorReader.fillTensor; + +/** + * Class used to read an add update for a tensor field. + */ +public class TensorAddUpdateReader { + + public static boolean isTensorField(Field field) { + return field.getDataType() instanceof TensorDataType; + } + + public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) { + expectObjectStart(buffer.currentToken()); + expectTensorTypeIsSparse(field); + + TensorDataType tensorDataType = (TensorDataType)field.getDataType(); + TensorType tensorType = tensorDataType.getTensorType(); + TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType); + fillTensor(buffer, tensorFieldValue); + return new TensorAddUpdate(tensorFieldValue); + } + + private static void expectTensorTypeIsSparse(Field field) { + TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); + if (tensorType.dimensions().stream() + .anyMatch(dim -> dim.isIndexed())) { + throw new IllegalArgumentException("An add update can only be applied to sparse tensors. " + + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + } + } + +} diff --git a/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java b/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java index cbdade6ad3c..c2b69062487 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/VespaJsonDocumentReader.java @@ -30,6 +30,8 @@ import static com.yahoo.document.json.readers.MapReader.UPDATE_MATCH; import static com.yahoo.document.json.readers.MapReader.createMapUpdate; import static com.yahoo.document.json.readers.SingleValueReader.UPDATE_ASSIGN; import static com.yahoo.document.json.readers.SingleValueReader.readSingleUpdate; +import static com.yahoo.document.json.readers.TensorAddUpdateReader.createTensorAddUpdate; +import static com.yahoo.document.json.readers.TensorAddUpdateReader.isTensorField; import static com.yahoo.document.json.readers.TensorModifyUpdateReader.UPDATE_MODIFY; import static com.yahoo.document.json.readers.TensorModifyUpdateReader.createModifyUpdate; @@ -119,7 +121,11 @@ public class VespaJsonDocumentReader { createRemoves(buffer, field, fieldUpdate); break; case UPDATE_ADD: - createAdds(buffer, field, fieldUpdate); + if (isTensorField(field)) { + fieldUpdate.addValueUpdate(createTensorAddUpdate(buffer, field)); + } else { + createAdds(buffer, field, fieldUpdate); + } break; case UPDATE_MATCH: fieldUpdate.addValueUpdate(createMapUpdate(buffer, field)); diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java new file mode 100644 index 00000000000..3703ffc17a2 --- /dev/null +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -0,0 +1,68 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.document.update; + +import com.yahoo.document.DataType; +import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.FieldValue; +import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.document.serialization.DocumentUpdateWriter; + +import java.util.Objects; + +/** + * An update used to add cells to a sparse tensor (has only mapped dimensions). + * + * The cells to add are contained in a sparse tensor as well. + */ +public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { + + private TensorFieldValue tensor; + + public TensorAddUpdate(TensorFieldValue tensor) { + super(ValueUpdateClassID.TENSORADD); + this.tensor = tensor; + } + + @Override + protected void checkCompatibility(DataType fieldType) { + if (!(fieldType instanceof TensorDataType)) { + throw new UnsupportedOperationException("Expected tensor type, got " + fieldType.getName() + "."); + } + } + + @Override + public void serialize(DocumentUpdateWriter data, DataType superType) { + // TODO: implement + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public FieldValue applyTo(FieldValue oldValue) { + // TODO: implement + return null; + } + + @Override + public TensorFieldValue getValue() { + return tensor; + } + + @Override + public void setValue(TensorFieldValue value) { + tensor = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + TensorAddUpdate that = (TensorAddUpdate) o; + return tensor.equals(that.tensor); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), tensor); + } +} diff --git a/document/src/main/java/com/yahoo/document/update/ValueUpdate.java b/document/src/main/java/com/yahoo/document/update/ValueUpdate.java index 9600b820549..4e4dae60589 100644 --- a/document/src/main/java/com/yahoo/document/update/ValueUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/ValueUpdate.java @@ -343,7 +343,8 @@ public abstract class ValueUpdate<T extends FieldValue> { CLEAR(28, "clear"), MAP(29, "map"), REMOVE(30, "remove"), - TENSORMODIFY(100, "tensormodify"); + TENSORMODIFY(100, "tensormodify"), + TENSORADD(101, "tensoradd"); public final int id; public final String name; diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 2588d56e24f..ecbd615a633 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -43,6 +43,7 @@ import com.yahoo.document.update.AssignValueUpdate; import com.yahoo.document.update.ClearValueUpdate; import com.yahoo.document.update.FieldUpdate; import com.yahoo.document.update.MapValueUpdate; +import com.yahoo.document.update.TensorAddUpdate; import com.yahoo.document.update.TensorModifyUpdate; import com.yahoo.document.update.ValueUpdate; import com.yahoo.io.GrowableByteBuffer; @@ -1428,6 +1429,32 @@ public class JsonReaderTestCase { } @Test + public void tensor_add_update_on_sparse_tensor() { + assertTensorAddUpdate("{{x:a,y:b}:2.0, {x:c,y:d}: 3.0}", "sparse_tensor", + inputJson("{", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 },", + " { 'address': { 'x': 'c', 'y': 'd' }, 'value': 3.0 } ]}")); + } + + @Test + public void tensor_add_update_on_non_sparse_tensor_throws() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + createTensorAddUpdate(inputJson("{", + " 'cells': [] }"), "mixed_tensor"); + } + + @Test + public void tensor_add_update_on_not_fully_specified_cell_throws() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Missing a value for dimension y for tensor(x{},y{})"); + createTensorAddUpdate(inputJson("{", + " 'cells': [", + " { 'address': { 'x': 'a' }, 'value': 2.0 } ]}"), "sparse_tensor"); + } + + @Test public void require_that_parser_propagates_datatype_parser_errors_predicate() { assertParserErrorMatches( "Error in document 'id:unittest:testpredicate::0' - could not parse field 'boolean' of type 'predicate': " + @@ -1573,6 +1600,30 @@ public class JsonReaderTestCase { assertEquals(Tensor.from(expectedTensor), modifyUpdate.getValue().getTensor().get()); } + private DocumentUpdate createTensorAddUpdate(String tensorJson, String tensorFieldName) { + JsonReader reader = createReader(inputJson("[", + "{ 'update': '" + TENSOR_DOC_ID + "',", + " 'fields': {", + " '" + tensorFieldName + "': {", + " 'add': " + tensorJson + " }}}]")); + return (DocumentUpdate) reader.next(); + } + + private void assertTensorAddUpdate(String expectedTensor, String tensorFieldName, String tensorJson) { + assertTensorAddUpdate(expectedTensor, tensorFieldName, + createTensorAddUpdate(tensorJson, tensorFieldName)); + } + + private static void assertTensorAddUpdate(String expectedTensor, String tensorFieldName, DocumentUpdate update) { + assertEquals("testtensor", update.getId().getDocType()); + assertEquals(TENSOR_DOC_ID, update.getId().toString()); + assertEquals(1, update.fieldUpdates().size()); + FieldUpdate fieldUpdate = update.getFieldUpdate(tensorFieldName); + assertEquals(1, fieldUpdate.size()); + TensorAddUpdate addUpdate = (TensorAddUpdate) fieldUpdate.getValueUpdate(0); + assertEquals(Tensor.from(expectedTensor), addUpdate.getValue().getTensor().get()); + } + private static FieldUpdate getTensorField(DocumentUpdate update) { FieldUpdate fieldUpdate = update.getFieldUpdate("sparse_tensor"); assertEquals(1, fieldUpdate.size()); |