diff options
13 files changed, 134 insertions, 75 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json index f100178ee16..d4db3026b27 100644 --- a/document/abi-spec.json +++ b/document/abi-spec.json @@ -5278,6 +5278,7 @@ "public boolean equals(java.lang.Object)", "public int hashCode()", "public java.lang.String toString()", + "public static com.yahoo.tensor.TensorType extractSparseDimensions(com.yahoo.tensor.TensorType)", "public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)", "public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()" ], diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java index 5ed4455435a..6310fa62d15 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java @@ -27,11 +27,10 @@ public class TensorAddUpdateReader { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType tensorType = tensorDataType.getTensorType(); - TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType); fillTensor(buffer, tensorFieldValue); - expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get()); + expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get()); return new TensorAddUpdate(tensorFieldValue); } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index aa5fed78bfe..66588debbca 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -29,7 +29,6 @@ public class TensorModifyUpdateReader { private static final String MODIFY_MULTIPLY = "multiply"; public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) { - expectFieldIsOfTypeTensor(field); expectTensorTypeHasNoneIndexedUnboundDimensions(field); expectObjectStart(buffer.currentToken()); diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java index 0d12e7c074b..3bb4b2e262f 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java @@ -28,9 +28,9 @@ public class TensorRemoveUpdateReader { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType originalType = tensorDataType.getTensorType(); - TensorType convertedType = extractSparseDimensions(originalType); - + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType); Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType); + expectAddressesAreNonEmpty(field, tensor); return new TensorRemoveUpdate(new TensorFieldValue(tensor)); } @@ -87,9 +87,4 @@ public class TensorRemoveUpdateReader { return builder.build(); } - public static TensorType extractSparseDimensions(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); - type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); - return builder.build(); - } } diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java index fb252b1a30a..a763db33e7a 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java @@ -63,7 +63,7 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { } TensorDataType tensorDataType = (TensorDataType)type; TensorType tensorType = tensorDataType.getTensorType(); - TensorType convertedType = TensorRemoveUpdateReader.extractSparseDimensions(tensorType); + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType); TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java index d0d0cfc2480..f8d2464deb7 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -12,8 +12,6 @@ import java.util.Objects; /** * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension). - * - * The cells to add are contained in a sparse tensor. */ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { @@ -50,7 +48,7 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); Tensor update = tensor.getTensor().get(); - Tensor result = old.merge((left, right) -> right, update.cells()); + Tensor result = old.merge((left, right) -> right, update.cells()); // note this might be slow for large mixed tensor updates return new TensorFieldValue(result); } diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index 0c5345c7a9f..335cda8e133 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -7,6 +7,7 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.util.Objects; @@ -22,6 +23,18 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { public TensorRemoveUpdate(TensorFieldValue value) { super(ValueUpdateClassID.TENSORREMOVE); this.tensor = value; + verifyCompatibleType(); + } + + private void verifyCompatibleType() { + if ( ! tensor.getTensor().isPresent()) { + throw new IllegalArgumentException("Tensor must be present in remove update"); + } + TensorType tensorType = tensor.getTensor().get().type(); + TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType()); + if ( ! tensorType.equals(expectedType)) { + throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'"); + } } @Override @@ -83,4 +96,11 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { return super.toString() + " " + tensor; } + public static TensorType extractSparseDimensions(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); + return builder.build(); + } + + } diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java index 288bd112cd6..6935c54ba2a 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java @@ -10,32 +10,12 @@ import static org.junit.Assert.assertEquals; public class TensorAddUpdateTest { @Test - public void apply_add_update_operations_sparse() { - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); - assertSparseApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); + public void apply_add_update_operations() { + assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); } - @Test - public void apply_add_update_operations_mixed() { - assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}"); - assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}"); - assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}"); - assertMixedApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5,{x:0,y:1}:0,{x:0,y:2}:0}"); - assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:0}"); - } - - private void assertSparseApplyTo(String init, String update, String expected) { - assertApplyTo("tensor(x{},y{})", init, update, expected); - } - - private void assertMixedApplyTo(String init, String update, String expected) { - assertApplyTo("tensor(x{},y[3])", init, update, expected); - } - - private void assertApplyTo(String spec, String init, String update, String expected) { + private void assertApplyTo(String init, String update, String expected) { + String spec = "tensor(x{},y{})"; TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update))); Tensor updated = ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get(); diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java index 20d0ccbcb3d..b885e6ddca0 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java @@ -41,20 +41,8 @@ public class TensorModifyUpdateTest { public void apply_modify_update_operations() { assertApplyTo("tensor(x{},y{})", Operation.REPLACE, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); - assertApplyTo("tensor(x{},y{})", Operation.ADD, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x{},y{})", Operation.MULTIPLY, - "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); - assertApplyTo("tensor(x[1],y[2])", Operation.REPLACE, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); assertApplyTo("tensor(x[1],y[2])", Operation.ADD, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x[1],y[2])", Operation.MULTIPLY, - "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); - assertApplyTo("tensor(x{},y[2])", Operation.REPLACE, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); - assertApplyTo("tensor(x{},y[2])", Operation.ADD, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); assertApplyTo("tensor(x{},y[2])", Operation.MULTIPLY, "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); } diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java index 52ed6c63356..3a005e858c8 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java @@ -10,32 +10,14 @@ import static org.junit.Assert.assertEquals; public class TensorRemoveUpdateTest { @Test - public void apply_remove_update_operations_sparse() { - assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}"); - assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}"); - assertSparseApplyTo("{}", "{{x:0,y:0}:1}", "{}"); - assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); + public void apply_remove_update_operations() { + assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}"); } - @Test - public void apply_remove_update_operations_mixed() { - assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0}:1}", "{}"); - assertMixedApplyTo("{{x:0,y:0}:1, {x:1,y:0}:2}", "{{x:0}:1}", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}"); - assertMixedApplyTo("{}", "{{x:0}:1}", "{}"); - assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); - } - - private void assertSparseApplyTo(String init, String update, String expected) { - assertApplyTo("tensor(x{},y{})", "tensor(x{},y{})", init, update, expected); - } - - private void assertMixedApplyTo(String init, String update, String expected) { - assertApplyTo("tensor(x{},y[3])", "tensor(x{})", init, update, expected); - } - - private void assertApplyTo(String spec, String updateSpec, String init, String update, String expected) { + private void assertApplyTo(String init, String update, String expected) { + String spec = "tensor(x{},y{})"; TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); - TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(updateSpec, update))); + TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(spec, update))); TensorFieldValue updatedFieldValue = (TensorFieldValue) removeUpdate.applyTo(initialFieldValue); assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index c51d3c32df9..08878edeb83 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -130,9 +130,11 @@ public class MixedTensor implements Tensor { @Override public Tensor remove(Set<TensorAddress> addresses) { Tensor.Builder builder = Tensor.Builder.of(type()); + + // iterate through all sparse addresses referencing a dense subspace for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) { TensorAddress sparsePartialAddress = entry.getKey(); - if ( ! addresses.contains(sparsePartialAddress)) { + if ( ! addresses.contains(sparsePartialAddress)) { // assumption: addresses only contain the sparse part long offset = entry.getValue(); for (int i = 0; i < index.denseSubspaceSize; ++i) { Cell cell = cells.get((int)offset + i); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index a2333f41135..eb16801c306 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -129,7 +129,8 @@ public interface Tensor { /** * Returns a new tensor where existing cells in this tensor have been * removed according to the given set of addresses. Only valid for sparse - * or mixed tensors. + * or mixed tensors. For mixed tensors, addresses are assumed to only + * contain the sparse dimensions, as the entire dense subspace is removed. * * @param addresses list of addresses to remove * @return a new tensor where cells have been removed diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 2c9eefbd130..02d16e6f3e4 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -151,12 +151,106 @@ public class TensorTestCase { Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), Tensor.from("tensor(x[1],y[3])", "{}"), Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}")); + assertTensorModify((left, right) -> left * right, + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:6}")); + } + + @Test + public void testTensorMerge() { + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:2}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3,{x:0,y:2}:4}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:2}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}")); // notice difference with sparse case - y is dense dimension here with default value 0.0 + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3,{x:0,y:2}:4}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}")); + } + + @Test + public void testTensorRemove() { + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:1}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:1}"), + Tensor.from("tensor(x{},y{})", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1}"), + Tensor.from("tensor(x{},y{})", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2, {x:0,y:1}:3}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), // notice update is without dense dimension + Tensor.from("tensor(x{},y[3])", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:1,y:0}:2}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), + Tensor.from("tensor(x{},y[3])", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), + Tensor.from("tensor(x{},y[3])", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{})", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}")); } private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) { assertEquals(expected, init.modify(op, update.cells())); } + private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) { + DoubleBinaryOperator op = (left, right) -> right; + assertEquals(expected, init.merge(op, update.cells())); + } + + private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) { + assertEquals(expected, init.remove(update.cells().keySet())); + } + + private double dotProduct(Tensor tensor, List<Tensor> tensors) { double sum = 0; TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), |