From 21651a8420530f069d42f37ca4dd0381f043501a Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 21 Feb 2019 10:55:55 +0100 Subject: Cleanup of tensor updates - Java --- document/abi-spec.json | 1 + .../json/readers/TensorAddUpdateReader.java | 3 +-- .../json/readers/TensorModifyUpdateReader.java | 1 - .../json/readers/TensorRemoveUpdateReader.java | 9 ++----- .../VespaDocumentDeserializerHead.java | 2 +- .../com/yahoo/document/update/TensorAddUpdate.java | 4 +--- .../yahoo/document/update/TensorRemoveUpdate.java | 20 ++++++++++++++++ .../yahoo/document/update/TensorAddUpdateTest.java | 28 ++++------------------ .../document/update/TensorModifyUpdateTest.java | 12 ---------- .../document/update/TensorRemoveUpdateTest.java | 28 ++++------------------ 10 files changed, 35 insertions(+), 73 deletions(-) (limited to 'document') diff --git a/document/abi-spec.json b/document/abi-spec.json index f100178ee16..d4db3026b27 100644 --- a/document/abi-spec.json +++ b/document/abi-spec.json @@ -5278,6 +5278,7 @@ "public boolean equals(java.lang.Object)", "public int hashCode()", "public java.lang.String toString()", + "public static com.yahoo.tensor.TensorType extractSparseDimensions(com.yahoo.tensor.TensorType)", "public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)", "public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()" ], diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java index 5ed4455435a..6310fa62d15 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java @@ -27,11 +27,10 @@ public class TensorAddUpdateReader { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType tensorType = tensorDataType.getTensorType(); - TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType); fillTensor(buffer, tensorFieldValue); - expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get()); + expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get()); return new TensorAddUpdate(tensorFieldValue); } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index aa5fed78bfe..66588debbca 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -29,7 +29,6 @@ public class TensorModifyUpdateReader { private static final String MODIFY_MULTIPLY = "multiply"; public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) { - expectFieldIsOfTypeTensor(field); expectTensorTypeHasNoneIndexedUnboundDimensions(field); expectObjectStart(buffer.currentToken()); diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java index 0d12e7c074b..3bb4b2e262f 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java @@ -28,9 +28,9 @@ public class TensorRemoveUpdateReader { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType originalType = tensorDataType.getTensorType(); - TensorType convertedType = extractSparseDimensions(originalType); - + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType); Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType); + expectAddressesAreNonEmpty(field, tensor); return new TensorRemoveUpdate(new TensorFieldValue(tensor)); } @@ -87,9 +87,4 @@ public class TensorRemoveUpdateReader { return builder.build(); } - public static TensorType extractSparseDimensions(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); - type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); - return builder.build(); - } } diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java index fb252b1a30a..a763db33e7a 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java @@ -63,7 +63,7 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { } TensorDataType tensorDataType = (TensorDataType)type; TensorType tensorType = tensorDataType.getTensorType(); - TensorType convertedType = TensorRemoveUpdateReader.extractSparseDimensions(tensorType); + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType); TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java index d0d0cfc2480..f8d2464deb7 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -12,8 +12,6 @@ import java.util.Objects; /** * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension). - * - * The cells to add are contained in a sparse tensor. */ public class TensorAddUpdate extends ValueUpdate { @@ -50,7 +48,7 @@ public class TensorAddUpdate extends ValueUpdate { Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); Tensor update = tensor.getTensor().get(); - Tensor result = old.merge((left, right) -> right, update.cells()); + Tensor result = old.merge((left, right) -> right, update.cells()); // note this might be slow for large mixed tensor updates return new TensorFieldValue(result); } diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index 0c5345c7a9f..335cda8e133 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -7,6 +7,7 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.util.Objects; @@ -22,6 +23,18 @@ public class TensorRemoveUpdate extends ValueUpdate { public TensorRemoveUpdate(TensorFieldValue value) { super(ValueUpdateClassID.TENSORREMOVE); this.tensor = value; + verifyCompatibleType(); + } + + private void verifyCompatibleType() { + if ( ! tensor.getTensor().isPresent()) { + throw new IllegalArgumentException("Tensor must be present in remove update"); + } + TensorType tensorType = tensor.getTensor().get().type(); + TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType()); + if ( ! tensorType.equals(expectedType)) { + throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'"); + } } @Override @@ -83,4 +96,11 @@ public class TensorRemoveUpdate extends ValueUpdate { return super.toString() + " " + tensor; } + public static TensorType extractSparseDimensions(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); + return builder.build(); + } + + } diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java index 288bd112cd6..6935c54ba2a 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java @@ -10,32 +10,12 @@ import static org.junit.Assert.assertEquals; public class TensorAddUpdateTest { @Test - public void apply_add_update_operations_sparse() { - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); - assertSparseApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); + public void apply_add_update_operations() { + assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); } - @Test - public void apply_add_update_operations_mixed() { - assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}"); - assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}"); - assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}"); - assertMixedApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5,{x:0,y:1}:0,{x:0,y:2}:0}"); - assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:0}"); - } - - private void assertSparseApplyTo(String init, String update, String expected) { - assertApplyTo("tensor(x{},y{})", init, update, expected); - } - - private void assertMixedApplyTo(String init, String update, String expected) { - assertApplyTo("tensor(x{},y[3])", init, update, expected); - } - - private void assertApplyTo(String spec, String init, String update, String expected) { + private void assertApplyTo(String init, String update, String expected) { + String spec = "tensor(x{},y{})"; TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update))); Tensor updated = ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get(); diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java index 20d0ccbcb3d..b885e6ddca0 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java @@ -41,20 +41,8 @@ public class TensorModifyUpdateTest { public void apply_modify_update_operations() { assertApplyTo("tensor(x{},y{})", Operation.REPLACE, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); - assertApplyTo("tensor(x{},y{})", Operation.ADD, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x{},y{})", Operation.MULTIPLY, - "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); - assertApplyTo("tensor(x[1],y[2])", Operation.REPLACE, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); assertApplyTo("tensor(x[1],y[2])", Operation.ADD, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x[1],y[2])", Operation.MULTIPLY, - "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); - assertApplyTo("tensor(x{},y[2])", Operation.REPLACE, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); - assertApplyTo("tensor(x{},y[2])", Operation.ADD, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); assertApplyTo("tensor(x{},y[2])", Operation.MULTIPLY, "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); } diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java index 52ed6c63356..3a005e858c8 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java @@ -10,32 +10,14 @@ import static org.junit.Assert.assertEquals; public class TensorRemoveUpdateTest { @Test - public void apply_remove_update_operations_sparse() { - assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}"); - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}"); - assertSparseApplyTo("{}", "{{x:0,y:0}:1}", "{}"); - assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); + public void apply_remove_update_operations() { + assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}"); } - @Test - public void apply_remove_update_operations_mixed() { - assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0}:1}", "{}"); - assertMixedApplyTo("{{x:0,y:0}:1, {x:1,y:0}:2}", "{{x:0}:1}", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}"); - assertMixedApplyTo("{}", "{{x:0}:1}", "{}"); - assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); - } - - private void assertSparseApplyTo(String init, String update, String expected) { - assertApplyTo("tensor(x{},y{})", "tensor(x{},y{})", init, update, expected); - } - - private void assertMixedApplyTo(String init, String update, String expected) { - assertApplyTo("tensor(x{},y[3])", "tensor(x{})", init, update, expected); - } - - private void assertApplyTo(String spec, String updateSpec, String init, String update, String expected) { + private void assertApplyTo(String init, String update, String expected) { + String spec = "tensor(x{},y{})"; TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); - TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(updateSpec, update))); + TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(spec, update))); TensorFieldValue updatedFieldValue = (TensorFieldValue) removeUpdate.applyTo(initialFieldValue); assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); } -- cgit v1.2.3