diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-20 14:30:31 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-20 14:30:31 +0100 |
commit | c85a3fee56c13f82d14d480e7569432e1f352316 (patch) | |
tree | 1ba19b8b498a7c4e0004939a8139fcfbd8d75875 /document | |
parent | 085b6922c07f4626c61e2ed2e6dde6beec0855de (diff) |
TensorRemoveUpdate support for mixed tensors
Diffstat (limited to 'document')
6 files changed, 107 insertions, 37 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java index 210a6a80ee5..0d12e7c074b 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java @@ -24,23 +24,23 @@ public class TensorRemoveUpdateReader { static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); - expectTensorTypeIsSparse(field); + expectTensorTypeHasSparseDimensions(field); TensorDataType tensorDataType = (TensorDataType)field.getDataType(); - TensorType tensorType = tensorDataType.getTensorType(); + TensorType originalType = tensorDataType.getTensorType(); + TensorType convertedType = extractSparseDimensions(originalType); - // TODO: for mixed case extract a new tensor type based only on mapped dimensions - - Tensor tensor = readRemoveUpdateTensor(buffer, tensorType); + Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType); expectAddressesAreNonEmpty(field, tensor); return new TensorRemoveUpdate(new TensorFieldValue(tensor)); } - private static void expectTensorTypeIsSparse(Field field) { + private static void expectTensorTypeHasSparseDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - if (tensorType.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed)) { - throw new IllegalArgumentException("A remove update can only be applied to sparse tensors. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { + throw new IllegalArgumentException("A remove update can only be applied to tensors " + + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -53,7 +53,7 @@ public class TensorRemoveUpdateReader { /** * Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0 */ - private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) { + private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) { Tensor.Builder builder = Tensor.Builder.of(type); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); @@ -62,7 +62,7 @@ public class TensorRemoveUpdateReader { expectArrayStart(buffer.currentToken()); int nesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) { - builder.cell(readTensorAddress(buffer, type), 1.0); + builder.cell(readTensorAddress(buffer, type, originalType), 1.0); } expectCompositeEnd(buffer.currentToken()); } @@ -71,12 +71,15 @@ public class TensorRemoveUpdateReader { return builder.build(); } - private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) { + private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) { TensorAddress.Builder builder = new TensorAddress.Builder(type); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { String dimension = buffer.currentName(); + if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) { + throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update"); + } String label = buffer.currentText(); builder.add(dimension, label); } @@ -84,4 +87,9 @@ public class TensorRemoveUpdateReader { return builder.build(); } + public static TensorType extractSparseDimensions(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); + return builder.build(); + } } diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java index 2f22def9aa1..2ab7169fae2 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java @@ -5,6 +5,7 @@ import com.yahoo.document.DataType; import com.yahoo.document.DocumentTypeManager; import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.document.json.readers.TensorRemoveUpdateReader; import com.yahoo.document.update.TensorAddUpdate; import com.yahoo.document.update.TensorModifyUpdate; import com.yahoo.document.update.TensorRemoveUpdate; @@ -46,7 +47,10 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { throw new DeserializationException("Expected tensor data type, got " + type); } TensorDataType tensorDataType = (TensorDataType)type; - TensorFieldValue tensor = new TensorFieldValue(tensorDataType.getTensorType()); + TensorType tensorType = tensorDataType.getTensorType(); + TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(tensorType); + + TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); return new TensorAddUpdate(tensor); } @@ -58,10 +62,9 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { } TensorDataType tensorDataType = (TensorDataType)type; TensorType tensorType = tensorDataType.getTensorType(); + TensorType convertedType = TensorRemoveUpdateReader.extractSparseDimensions(tensorType); - // TODO: for mixed case extract a new tensor type based only on mapped dimensions - - TensorFieldValue tensor = new TensorFieldValue(tensorType); + TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); return new TensorRemoveUpdate(tensor); } diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index e9fb1e3efd5..fb046f15c2c 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -51,17 +51,10 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } - Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); - Map<TensorAddress, Double> cellsToRemove = tensor.getTensor().get().cells(); - Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); - for (Iterator<Tensor.Cell> i = oldTensor.cellIterator(); i.hasNext(); ) { - Tensor.Cell cell = i.next(); - TensorAddress address = cell.getKey(); - if ( ! cellsToRemove.containsKey(address)) { - builder.cell(address, cell.getValue()); - } - } - return new TensorFieldValue(builder.build()); + Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); + Tensor update = tensor.getTensor().get(); + Tensor result = old.remove(update.cells().keySet()); + return new TensorFieldValue(result); } @Override diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java index e2736dabd2b..01293cb9782 100644 --- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java +++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java @@ -40,6 +40,7 @@ public class DocumentUpdateJsonSerializerTest { final static TensorType sparseTensorType = new TensorType.Builder().mapped("x").mapped("y").build(); final static TensorType denseTensorType = new TensorType.Builder().indexed("x", 2).indexed("y", 3).build(); + final static TensorType mixedTensorType = new TensorType.Builder().mapped("x").indexed("y", 3).build(); final static DocumentTypeManager types = new DocumentTypeManager(); final static JsonFactory parserFactory = new JsonFactory(); final static DocumentType docType = new DocumentType("doctype"); @@ -60,6 +61,7 @@ public class DocumentUpdateJsonSerializerTest { docType.addField(new Field("byte_field", DataType.BYTE)); docType.addField(new Field("sparse_tensor", new TensorDataType(sparseTensorType))); docType.addField(new Field("dense_tensor", new TensorDataType(denseTensorType))); + docType.addField(new Field("mixed_tensor", new TensorDataType(mixedTensorType))); docType.addField(new Field("reference_field", new ReferenceDataType(refTargetDocType, 777))); docType.addField(new Field("predicate_field", DataType.PREDICATE)); docType.addField(new Field("raw_field", DataType.RAW)); @@ -355,6 +357,25 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_add_update_mixed() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'add': {", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': '1', 'y': '2' }, 'value': 3.0 }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void test_tensor_remove_update() { roundtripSerializeJsonAndMatch(inputJson( "{", @@ -374,6 +395,25 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_remove_update_mixed() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'remove': {", + " 'addresses': [", + " {'x':'0' }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + + @Test public void reference_field_id_can_be_update_assigned_non_empty_id() { roundtripSerializeJsonAndMatch(inputJson( "{", diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index a20276e5c65..fe24a755d1d 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1489,6 +1489,7 @@ public class JsonReaderTestCase { exception.expect(IllegalArgumentException.class); exception.expectMessage("Add update for field 'sparse_tensor' does not contain tensor cells"); createTensorAddUpdate(inputJson("{}"), "sparse_tensor"); + createTensorAddUpdate(inputJson("{}"), "mixed_tensor"); } @Test @@ -1500,7 +1501,6 @@ public class JsonReaderTestCase { " { 'x': 'c', 'y': 'd' } ]}")); } - @Ignore @Test public void tensor_remove_update_on_mixed_tensor() { assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor", @@ -1511,9 +1511,19 @@ public class JsonReaderTestCase { } @Test + public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Indexed dimension address 'y' should not be specified in remove update"); + createTensorRemoveUpdate(inputJson("{", + " 'addresses': [", + " { 'x': '1', 'y': '0' },", + " { 'x': '2', 'y': '0' } ]}"), "mixed_tensor"); + } + + @Test public void tensor_remove_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); + exception.expectMessage("A remove update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorRemoveUpdate(inputJson("{", " 'addresses': [] }"), "dense_tensor"); } @@ -1532,6 +1542,7 @@ public class JsonReaderTestCase { exception.expect(IllegalArgumentException.class); exception.expectMessage("Remove update for field 'sparse_tensor' does not contain tensor addresses"); createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "sparse_tensor"); + createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "mixed_tensor"); } @Test diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java index 40ab00facdb..52ed6c63356 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java @@ -10,17 +10,32 @@ import static org.junit.Assert.assertEquals; public class TensorRemoveUpdateTest { @Test - public void apply_remove_update_operations() { - assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}"); - assertApplyTo("{}", "{{x:0,y:0}:1}", "{}"); - assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); + public void apply_remove_update_operations_sparse() { + assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}"); + assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}"); + assertSparseApplyTo("{}", "{{x:0,y:0}:1}", "{}"); + assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); } - private void assertApplyTo(String init, String update, String expected) { - String spec = "tensor(x{},y{})"; + @Test + public void apply_remove_update_operations_mixed() { + assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0}:1}", "{}"); + assertMixedApplyTo("{{x:0,y:0}:1, {x:1,y:0}:2}", "{{x:0}:1}", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}"); + assertMixedApplyTo("{}", "{{x:0}:1}", "{}"); + assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); + } + + private void assertSparseApplyTo(String init, String update, String expected) { + assertApplyTo("tensor(x{},y{})", "tensor(x{},y{})", init, update, expected); + } + + private void assertMixedApplyTo(String init, String update, String expected) { + assertApplyTo("tensor(x{},y[3])", "tensor(x{})", init, update, expected); + } + + private void assertApplyTo(String spec, String updateSpec, String init, String update, String expected) { TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); - TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(spec, update))); + TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(updateSpec, update))); TensorFieldValue updatedFieldValue = (TensorFieldValue) removeUpdate.applyTo(initialFieldValue); assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); } |