summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-02-21 10:55:55 +0100
committerLester Solbakken <lesters@oath.com>2019-02-21 10:55:55 +0100
commit21651a8420530f069d42f37ca4dd0381f043501a (patch)
tree41a1d392c21fd43ffcb617d337490890c254d1b2
parent5d0bff5230d3d8a304f786cbcc3c486ee9f941bb (diff)
Cleanup of tensor updates - Java
-rw-r--r--document/abi-spec.json1
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java3
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java1
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java9
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java2
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java4
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java20
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java28
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java12
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java3
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java94
13 files changed, 134 insertions, 75 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json
index f100178ee16..d4db3026b27 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -5278,6 +5278,7 @@
"public boolean equals(java.lang.Object)",
"public int hashCode()",
"public java.lang.String toString()",
+ "public static com.yahoo.tensor.TensorType extractSparseDimensions(com.yahoo.tensor.TensorType)",
"public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)",
"public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()"
],
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
index 5ed4455435a..6310fa62d15 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
@@ -27,11 +27,10 @@ public class TensorAddUpdateReader {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType tensorType = tensorDataType.getTensorType();
-
TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType);
fillTensor(buffer, tensorFieldValue);
- expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get());
+ expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get());
return new TensorAddUpdate(tensorFieldValue);
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index aa5fed78bfe..66588debbca 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -29,7 +29,6 @@ public class TensorModifyUpdateReader {
private static final String MODIFY_MULTIPLY = "multiply";
public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
-
expectFieldIsOfTypeTensor(field);
expectTensorTypeHasNoneIndexedUnboundDimensions(field);
expectObjectStart(buffer.currentToken());
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 0d12e7c074b..3bb4b2e262f 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -28,9 +28,9 @@ public class TensorRemoveUpdateReader {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
- TensorType convertedType = extractSparseDimensions(originalType);
-
+ TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType);
Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType);
+
expectAddressesAreNonEmpty(field, tensor);
return new TensorRemoveUpdate(new TensorFieldValue(tensor));
}
@@ -87,9 +87,4 @@ public class TensorRemoveUpdateReader {
return builder.build();
}
- public static TensorType extractSparseDimensions(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder();
- type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name()));
- return builder.build();
- }
}
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index fb252b1a30a..a763db33e7a 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -63,7 +63,7 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
}
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
- TensorType convertedType = TensorRemoveUpdateReader.extractSparseDimensions(tensorType);
+ TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType);
TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
index d0d0cfc2480..f8d2464deb7 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
@@ -12,8 +12,6 @@ import java.util.Objects;
/**
* An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension).
- *
- * The cells to add are contained in a sparse tensor.
*/
public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
@@ -50,7 +48,7 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
Tensor update = tensor.getTensor().get();
- Tensor result = old.merge((left, right) -> right, update.cells());
+ Tensor result = old.merge((left, right) -> right, update.cells()); // note this might be slow for large mixed tensor updates
return new TensorFieldValue(result);
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index 0c5345c7a9f..335cda8e133 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -7,6 +7,7 @@ import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import java.util.Objects;
@@ -22,6 +23,18 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
public TensorRemoveUpdate(TensorFieldValue value) {
super(ValueUpdateClassID.TENSORREMOVE);
this.tensor = value;
+ verifyCompatibleType();
+ }
+
+ private void verifyCompatibleType() {
+ if ( ! tensor.getTensor().isPresent()) {
+ throw new IllegalArgumentException("Tensor must be present in remove update");
+ }
+ TensorType tensorType = tensor.getTensor().get().type();
+ TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType());
+ if ( ! tensorType.equals(expectedType)) {
+ throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'");
+ }
}
@Override
@@ -83,4 +96,11 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return super.toString() + " " + tensor;
}
+ public static TensorType extractSparseDimensions(TensorType type) {
+ TensorType.Builder builder = new TensorType.Builder();
+ type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name()));
+ return builder.build();
+ }
+
+
}
diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
index 288bd112cd6..6935c54ba2a 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
@@ -10,32 +10,12 @@ import static org.junit.Assert.assertEquals;
public class TensorAddUpdateTest {
@Test
- public void apply_add_update_operations_sparse() {
- assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}");
- assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}");
- assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}");
- assertSparseApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}");
- assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}");
+ public void apply_add_update_operations() {
+ assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}");
}
- @Test
- public void apply_add_update_operations_mixed() {
- assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}");
- assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}");
- assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}");
- assertMixedApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5,{x:0,y:1}:0,{x:0,y:2}:0}");
- assertMixedApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:0}");
- }
-
- private void assertSparseApplyTo(String init, String update, String expected) {
- assertApplyTo("tensor(x{},y{})", init, update, expected);
- }
-
- private void assertMixedApplyTo(String init, String update, String expected) {
- assertApplyTo("tensor(x{},y[3])", init, update, expected);
- }
-
- private void assertApplyTo(String spec, String init, String update, String expected) {
+ private void assertApplyTo(String init, String update, String expected) {
+ String spec = "tensor(x{},y{})";
TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init));
TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update)));
Tensor updated = ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get();
diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
index 20d0ccbcb3d..b885e6ddca0 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
@@ -41,20 +41,8 @@ public class TensorModifyUpdateTest {
public void apply_modify_update_operations() {
assertApplyTo("tensor(x{},y{})", Operation.REPLACE,
"{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}");
- assertApplyTo("tensor(x{},y{})", Operation.ADD,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}");
- assertApplyTo("tensor(x{},y{})", Operation.MULTIPLY,
- "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
- assertApplyTo("tensor(x[1],y[2])", Operation.REPLACE,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}");
assertApplyTo("tensor(x[1],y[2])", Operation.ADD,
"{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}");
- assertApplyTo("tensor(x[1],y[2])", Operation.MULTIPLY,
- "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
- assertApplyTo("tensor(x{},y[2])", Operation.REPLACE,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}");
- assertApplyTo("tensor(x{},y[2])", Operation.ADD,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}");
assertApplyTo("tensor(x{},y[2])", Operation.MULTIPLY,
"{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
}
diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
index 52ed6c63356..3a005e858c8 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
@@ -10,32 +10,14 @@ import static org.junit.Assert.assertEquals;
public class TensorRemoveUpdateTest {
@Test
- public void apply_remove_update_operations_sparse() {
- assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}");
- assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}");
- assertSparseApplyTo("{}", "{{x:0,y:0}:1}", "{}");
- assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}");
+ public void apply_remove_update_operations() {
+ assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}");
}
- @Test
- public void apply_remove_update_operations_mixed() {
- assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0}:1}", "{}");
- assertMixedApplyTo("{{x:0,y:0}:1, {x:1,y:0}:2}", "{{x:0}:1}", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}");
- assertMixedApplyTo("{}", "{{x:0}:1}", "{}");
- assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}");
- }
-
- private void assertSparseApplyTo(String init, String update, String expected) {
- assertApplyTo("tensor(x{},y{})", "tensor(x{},y{})", init, update, expected);
- }
-
- private void assertMixedApplyTo(String init, String update, String expected) {
- assertApplyTo("tensor(x{},y[3])", "tensor(x{})", init, update, expected);
- }
-
- private void assertApplyTo(String spec, String updateSpec, String init, String update, String expected) {
+ private void assertApplyTo(String init, String update, String expected) {
+ String spec = "tensor(x{},y{})";
TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init));
- TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(updateSpec, update)));
+ TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(spec, update)));
TensorFieldValue updatedFieldValue = (TensorFieldValue) removeUpdate.applyTo(initialFieldValue);
assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get());
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index c51d3c32df9..08878edeb83 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -130,9 +130,11 @@ public class MixedTensor implements Tensor {
@Override
public Tensor remove(Set<TensorAddress> addresses) {
Tensor.Builder builder = Tensor.Builder.of(type());
+
+ // iterate through all sparse addresses referencing a dense subspace
for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) {
TensorAddress sparsePartialAddress = entry.getKey();
- if ( ! addresses.contains(sparsePartialAddress)) {
+ if ( ! addresses.contains(sparsePartialAddress)) { // assumption: addresses only contain the sparse part
long offset = entry.getValue();
for (int i = 0; i < index.denseSubspaceSize; ++i) {
Cell cell = cells.get((int)offset + i);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index a2333f41135..eb16801c306 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -129,7 +129,8 @@ public interface Tensor {
/**
* Returns a new tensor where existing cells in this tensor have been
* removed according to the given set of addresses. Only valid for sparse
- * or mixed tensors.
+ * or mixed tensors. For mixed tensors, addresses are assumed to only
+ * contain the sparse dimensions, as the entire dense subspace is removed.
*
* @param addresses list of addresses to remove
* @return a new tensor where cells have been removed
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 2c9eefbd130..02d16e6f3e4 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -151,12 +151,106 @@ public class TensorTestCase {
Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
Tensor.from("tensor(x[1],y[3])", "{}"),
Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}"));
+ assertTensorModify((left, right) -> left * right,
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:6}"));
+ }
+
+ @Test
+ public void testTensorMerge() {
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:2}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3,{x:0,y:2}:4}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:2}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}")); // notice difference with sparse case - y is dense dimension here with default value 0.0
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3,{x:0,y:2}:4}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"));
+ }
+
+ @Test
+ public void testTensorRemove() {
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:1}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:1}"),
+ Tensor.from("tensor(x{},y{})", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1}"),
+ Tensor.from("tensor(x{},y{})", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2, {x:0,y:1}:3}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"), // notice update is without dense dimension
+ Tensor.from("tensor(x{},y[3])", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:1,y:0}:2}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"),
+ Tensor.from("tensor(x{},y[3])", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{})", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"));
}
private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) {
assertEquals(expected, init.modify(op, update.cells()));
}
+ private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) {
+ DoubleBinaryOperator op = (left, right) -> right;
+ assertEquals(expected, init.merge(op, update.cells()));
+ }
+
+ private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) {
+ assertEquals(expected, init.remove(update.cells().keySet()));
+ }
+
+
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),