summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-02-20 12:46:24 +0100
committerLester Solbakken <lesters@oath.com>2019-02-20 12:46:24 +0100
commit085b6922c07f4626c61e2ed2e6dde6beec0855de (patch)
tree597fc14c08199339c9ab9286c365af6e8d4cdcdb
parent85e394563c8b711a1a0307c8ac5953c1817f5629 (diff)
TensorAddUpdate support for mixed tensors
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java27
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java23
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java24
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java41
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java33
-rw-r--r--vespajlib/abi-spec.json4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java14
10 files changed, 158 insertions, 55 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
index ffbfe49347c..da8bcc13397 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
@@ -6,10 +6,15 @@ import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
import com.yahoo.document.update.TensorAddUpdate;
+import com.yahoo.document.update.TensorModifyUpdate;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.util.Iterator;
+
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart;
+import static com.yahoo.document.json.readers.TensorModifyUpdateReader.validateBounds;
import static com.yahoo.document.json.readers.TensorReader.fillTensor;
/**
@@ -23,22 +28,27 @@ public class TensorAddUpdateReader {
public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) {
expectObjectStart(buffer.currentToken());
- expectTensorTypeIsSparse(field);
+ expectTensorTypeHasSparseDimensions(field);
+ // Convert update type to sparse
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
- TensorType tensorType = tensorDataType.getTensorType();
- TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType);
+ TensorType originalType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType);
+
+ TensorFieldValue tensorFieldValue = new TensorFieldValue(convertedType);
fillTensor(buffer, tensorFieldValue);
expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get());
+ validateBounds(tensorFieldValue.getTensor().get(), originalType);
+
return new TensorAddUpdate(tensorFieldValue);
}
- private static void expectTensorTypeIsSparse(Field field) {
+ private static void expectTensorTypeHasSparseDimensions(Field field) {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- if (tensorType.dimensions().stream()
- .anyMatch(dim -> dim.isIndexed())) {
- throw new IllegalArgumentException("An add update can only be applied to sparse tensors. "
- + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
+ throw new IllegalArgumentException("An add update can only be applied to tensors " +
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
@@ -48,5 +58,4 @@ public class TensorAddUpdateReader {
}
}
-
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index a9bbba519bd..5022185e03f 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -129,25 +129,26 @@ public class TensorModifyUpdateReader {
validateBounds(tensor, originalType);
- TensorFieldValue result = new TensorFieldValue(convertedType);
- result.assign(tensor);
- return result;
+ return new TensorFieldValue(tensor);
}
- /** Only validate if original type is indexed bound */
- private static void validateBounds(Tensor convertedTensor, TensorType originalType) {
- if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
+ /** Only validate if original type has indexed bound dimensions */
+ static void validateBounds(Tensor convertedTensor, TensorType originalType) {
+ if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
return;
}
for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) {
Tensor.Cell cell = iter.next();
TensorAddress address = cell.getKey();
for (int i = 0; i < address.size(); ++i) {
- long label = address.numericLabel(i);
- long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound
- if (label >= bound) {
- throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
- "' has label '" + label + "' but type is " + originalType.toString());
+ TensorType.Dimension dim = originalType.dimensions().get(i);
+ if (dim instanceof TensorType.IndexedBoundDimension) {
+ long label = address.numericLabel(i);
+ long bound = dim.size().get(); // size is non-optional for indexed bound
+ if (label >= bound) {
+ throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
+ "' has label '" + label + "' but type is " + originalType.toString());
+ }
}
}
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
index cfc3ee0c742..7059edbca7f 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
@@ -13,9 +13,9 @@ import java.util.Map;
import java.util.Objects;
/**
- * An update used to add cells to a sparse tensor (has only mapped dimensions).
+ * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension).
*
- * The cells to add are contained in a sparse tensor as well.
+ * The cells to add are contained in a sparse tensor.
*/
public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
@@ -50,22 +50,10 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
- Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get();
- Map<TensorAddress, Double> oldCells = oldTensor.cells();
- Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells();
-
- // currently, underlying implementation disallows multiple entries with the same key
-
- Tensor.Builder builder = Tensor.Builder.of(oldTensor.type());
- for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) {
- builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue()));
- }
- for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
- if ( ! oldCells.containsKey(addCell.getKey())) {
- builder.cell(addCell.getKey(), addCell.getValue());
- }
- }
- return new TensorFieldValue(builder.build());
+ Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
+ Tensor update = tensor.getTensor().get();
+ Tensor result = old.merge((left, right) -> right, update.cells());
+ return new TensorFieldValue(result);
}
@Override
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index e58b26d500d..a20276e5c65 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -56,6 +56,7 @@ import com.yahoo.text.Utf8;
import org.apache.commons.codec.binary.Base64;
import org.junit.After;
import org.junit.Before;
+import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -1449,11 +1450,29 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_add_update_on_non_sparse_tensor_throws() {
+ public void tensor_add_update_on_mixed_tensor() {
+ assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0}", "mixed_tensor",
+ inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_add_update_with_out_of_bound_dense_cells_throws() {
+ exception.expect(IndexOutOfBoundsException.class);
+ exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x{},y[3])");
+ createTensorAddUpdate(inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor");
+ }
+
+ @Test
+ public void tensor_add_update_on_dense_tensor_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'");
+ exception.expectMessage("An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
createTensorAddUpdate(inputJson("{",
- " 'cells': [] }"), "mixed_tensor");
+ " 'cells': [] }"), "dense_tensor");
}
@Test
@@ -1481,12 +1500,22 @@ public class JsonReaderTestCase {
" { 'x': 'c', 'y': 'd' } ]}"));
}
+ @Ignore
+ @Test
+ public void tensor_remove_update_on_mixed_tensor() {
+ assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor",
+ inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1' },",
+ " { 'x': '2' } ]}"));
+ }
+
@Test
- public void tensor_remove_update_on_non_sparse_tensor_throws() {
+ public void tensor_remove_update_on_dense_tensor_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'");
+ exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
createTensorRemoveUpdate(inputJson("{",
- " 'addresses': [] }"), "mixed_tensor");
+ " 'addresses': [] }"), "dense_tensor");
}
@Test
diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
index eb4001e6415..c6b21380e4b 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
@@ -3,27 +3,40 @@ package com.yahoo.document.update;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.tensor.Tensor;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
import static org.junit.Assert.assertEquals;
public class TensorAddUpdateTest {
+ @Rule
+ public ExpectedException exception = ExpectedException.none();
+
@Test
public void apply_add_update_operations() {
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}");
- assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}");
+ assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}");
+ assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}");
+ assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}");
+ assertApplyTo("tensor(x{},y{})", "{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}");
+ assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}");
+
+ assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}");
+ assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:0}");
+ assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}");
+ assertApplyTo("tensor(x{},y[3])", "{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5,{x:0,y:1}:0,{x:0,y:2}:0}");
+ assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:0}");
}
- private void assertApplyTo(String init, String update, String expected) {
- String spec = "tensor(x{},y{})";
+ private Tensor updateField(String spec, String init, String update) {
TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init));
- TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update)));
- TensorFieldValue updatedFieldValue = (TensorFieldValue) addUpdate.applyTo(initialFieldValue);
- assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get());
+ TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from("tensor(x{},y{})", update)));
+ return ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get();
+ }
+
+ private void assertApplyTo(String spec, String init, String update, String expected) {
+ assertEquals(Tensor.from(spec, expected), updateField(spec, init, update));
}
}
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 932513f8a57..480523982fa 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -808,6 +808,7 @@
"public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -852,6 +853,7 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)"
@@ -937,6 +939,7 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1039,6 +1042,7 @@
"public double asDouble()",
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index fb55b2d5014..704cead7c01 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -13,6 +13,7 @@ import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
/**
* An indexed (dense) tensor backed by a double array.
@@ -190,6 +191,11 @@ public class IndexedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) {
+ throw new IllegalArgumentException("Merge is not supported for indexed tensors");
+ }
+
+ @Override
public int hashCode() { return Arrays.hashCode(values); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index ec3020a1a4e..f44b3ce13b7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.Map;
+import java.util.function.DoubleBinaryOperator;
/**
* A sparse implementation of a tensor backed by a Map of cells to values.
@@ -51,6 +52,25 @@ public class MappedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
+
+ // currently, underlying implementation disallows multiple entries with the same key
+
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) {
+ TensorAddress address = cell.getKey();
+ double value = cell.getValue();
+ builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value);
+ }
+ for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
+ if ( ! cells.containsKey(addCell.getKey())) {
+ builder.cell(addCell.getKey(), addCell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 17e33c58a13..3630a016691 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -9,6 +9,7 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.function.DoubleBinaryOperator;
import java.util.stream.Collectors;
/**
@@ -70,13 +71,17 @@ public class MixedTensor implements Tensor {
return cells.iterator();
}
+ private Iterable<Cell> cellIterable() {
+ return this::cellIterator;
+ }
+
/**
* Returns an iterator over the values of this tensor.
* The iteration order is the same as for cellIterator.
*/
@Override
public Iterator<Double> valueIterator() {
- return new Iterator<Double>() {
+ return new Iterator<>() {
Iterator<Cell> cellIterator = cellIterator();
@Override
public boolean hasNext() {
@@ -108,6 +113,20 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Cell cell : cellIterable()) {
+ TensorAddress address = cell.getKey();
+ double value = cell.getValue();
+ builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value);
+ }
+ for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
+ builder.cell(addCell.getKey(), addCell.getValue());
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 8002990e5c6..175e6b41daa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -113,6 +113,20 @@ public interface Tensor {
return builder.build();
}
+ /**
+ * Returns a new tensor where existing cells in this tensor have been
+ * modified according to the given operation and cells in the given map.
+ * In contrast to {@link #modify}, previously non-existing cells are added
+ * to this tensor. Only valid for sparse or mixed tensors.
+ *
+ * @param op how to update overlapping cells
+ * @param cells cells to merge with this tensor
+ * @return a new tensor where this tensor is merged with the other
+ */
+ Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells);
+
+// Tensor remove(Tensor other);
+
// ----------------- Primitive tensor functions
default Tensor map(DoubleUnaryOperator mapper) {