summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-02-20 14:30:31 +0100
committerLester Solbakken <lesters@oath.com>2019-02-20 14:30:31 +0100
commitc85a3fee56c13f82d14d480e7569432e1f352316 (patch)
tree1ba19b8b498a7c4e0004939a8139fcfbd8d75875
parent085b6922c07f4626c61e2ed2e6dde6beec0855de (diff)
TensorRemoveUpdate support for mixed tensors
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java32
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java11
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java15
-rw-r--r--document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java40
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java15
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java31
-rw-r--r--vespa-hadoop/abi-spec.json8
-rw-r--r--vespajlib/abi-spec.json4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java11
12 files changed, 166 insertions, 38 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 210a6a80ee5..0d12e7c074b 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -24,23 +24,23 @@ public class TensorRemoveUpdateReader {
static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) {
expectObjectStart(buffer.currentToken());
- expectTensorTypeIsSparse(field);
+ expectTensorTypeHasSparseDimensions(field);
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
- TensorType tensorType = tensorDataType.getTensorType();
+ TensorType originalType = tensorDataType.getTensorType();
+ TensorType convertedType = extractSparseDimensions(originalType);
- // TODO: for mixed case extract a new tensor type based only on mapped dimensions
-
- Tensor tensor = readRemoveUpdateTensor(buffer, tensorType);
+ Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType);
expectAddressesAreNonEmpty(field, tensor);
return new TensorRemoveUpdate(new TensorFieldValue(tensor));
}
- private static void expectTensorTypeIsSparse(Field field) {
+ private static void expectTensorTypeHasSparseDimensions(Field field) {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- if (tensorType.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed)) {
- throw new IllegalArgumentException("A remove update can only be applied to sparse tensors. "
- + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
+ throw new IllegalArgumentException("A remove update can only be applied to tensors " +
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
@@ -53,7 +53,7 @@ public class TensorRemoveUpdateReader {
/**
* Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0
*/
- private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) {
+ private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) {
Tensor.Builder builder = Tensor.Builder.of(type);
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
@@ -62,7 +62,7 @@ public class TensorRemoveUpdateReader {
expectArrayStart(buffer.currentToken());
int nesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) {
- builder.cell(readTensorAddress(buffer, type), 1.0);
+ builder.cell(readTensorAddress(buffer, type, originalType), 1.0);
}
expectCompositeEnd(buffer.currentToken());
}
@@ -71,12 +71,15 @@ public class TensorRemoveUpdateReader {
return builder.build();
}
- private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) {
+ private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) {
TensorAddress.Builder builder = new TensorAddress.Builder(type);
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
String dimension = buffer.currentName();
+ if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) {
+ throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update");
+ }
String label = buffer.currentText();
builder.add(dimension, label);
}
@@ -84,4 +87,9 @@ public class TensorRemoveUpdateReader {
return builder.build();
}
+ public static TensorType extractSparseDimensions(TensorType type) {
+ TensorType.Builder builder = new TensorType.Builder();
+ type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name()));
+ return builder.build();
+ }
}
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index 2f22def9aa1..2ab7169fae2 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -5,6 +5,7 @@ import com.yahoo.document.DataType;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
+import com.yahoo.document.json.readers.TensorRemoveUpdateReader;
import com.yahoo.document.update.TensorAddUpdate;
import com.yahoo.document.update.TensorModifyUpdate;
import com.yahoo.document.update.TensorRemoveUpdate;
@@ -46,7 +47,10 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
throw new DeserializationException("Expected tensor data type, got " + type);
}
TensorDataType tensorDataType = (TensorDataType)type;
- TensorFieldValue tensor = new TensorFieldValue(tensorDataType.getTensorType());
+ TensorType tensorType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(tensorType);
+
+ TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
return new TensorAddUpdate(tensor);
}
@@ -58,10 +62,9 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
}
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorRemoveUpdateReader.extractSparseDimensions(tensorType);
- // TODO: for mixed case extract a new tensor type based only on mapped dimensions
-
- TensorFieldValue tensor = new TensorFieldValue(tensorType);
+ TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
return new TensorRemoveUpdate(tensor);
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index e9fb1e3efd5..fb046f15c2c 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -51,17 +51,10 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
- Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get();
- Map<TensorAddress, Double> cellsToRemove = tensor.getTensor().get().cells();
- Tensor.Builder builder = Tensor.Builder.of(oldTensor.type());
- for (Iterator<Tensor.Cell> i = oldTensor.cellIterator(); i.hasNext(); ) {
- Tensor.Cell cell = i.next();
- TensorAddress address = cell.getKey();
- if ( ! cellsToRemove.containsKey(address)) {
- builder.cell(address, cell.getValue());
- }
- }
- return new TensorFieldValue(builder.build());
+ Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
+ Tensor update = tensor.getTensor().get();
+ Tensor result = old.remove(update.cells().keySet());
+ return new TensorFieldValue(result);
}
@Override
diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
index e2736dabd2b..01293cb9782 100644
--- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
+++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
@@ -40,6 +40,7 @@ public class DocumentUpdateJsonSerializerTest {
final static TensorType sparseTensorType = new TensorType.Builder().mapped("x").mapped("y").build();
final static TensorType denseTensorType = new TensorType.Builder().indexed("x", 2).indexed("y", 3).build();
+ final static TensorType mixedTensorType = new TensorType.Builder().mapped("x").indexed("y", 3).build();
final static DocumentTypeManager types = new DocumentTypeManager();
final static JsonFactory parserFactory = new JsonFactory();
final static DocumentType docType = new DocumentType("doctype");
@@ -60,6 +61,7 @@ public class DocumentUpdateJsonSerializerTest {
docType.addField(new Field("byte_field", DataType.BYTE));
docType.addField(new Field("sparse_tensor", new TensorDataType(sparseTensorType)));
docType.addField(new Field("dense_tensor", new TensorDataType(denseTensorType)));
+ docType.addField(new Field("mixed_tensor", new TensorDataType(mixedTensorType)));
docType.addField(new Field("reference_field", new ReferenceDataType(refTargetDocType, 777)));
docType.addField(new Field("predicate_field", DataType.PREDICATE));
docType.addField(new Field("raw_field", DataType.RAW));
@@ -355,6 +357,25 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_add_update_mixed() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'add': {",
+ " 'cells': [",
+ " { 'address': { 'x': '0', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': '1', 'y': '2' }, 'value': 3.0 }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void test_tensor_remove_update() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
@@ -374,6 +395,25 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_remove_update_mixed() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'remove': {",
+ " 'addresses': [",
+ " {'x':'0' }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+
+ @Test
public void reference_field_id_can_be_update_assigned_non_empty_id() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index a20276e5c65..fe24a755d1d 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1489,6 +1489,7 @@ public class JsonReaderTestCase {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Add update for field 'sparse_tensor' does not contain tensor cells");
createTensorAddUpdate(inputJson("{}"), "sparse_tensor");
+ createTensorAddUpdate(inputJson("{}"), "mixed_tensor");
}
@Test
@@ -1500,7 +1501,6 @@ public class JsonReaderTestCase {
" { 'x': 'c', 'y': 'd' } ]}"));
}
- @Ignore
@Test
public void tensor_remove_update_on_mixed_tensor() {
assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor",
@@ -1511,9 +1511,19 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() {
+ exception.expect(IllegalArgumentException.class);
+ exception.expectMessage("Indexed dimension address 'y' should not be specified in remove update");
+ createTensorRemoveUpdate(inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1', 'y': '0' },",
+ " { 'x': '2', 'y': '0' } ]}"), "mixed_tensor");
+ }
+
+ @Test
public void tensor_remove_update_on_dense_tensor_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
+ exception.expectMessage("A remove update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
createTensorRemoveUpdate(inputJson("{",
" 'addresses': [] }"), "dense_tensor");
}
@@ -1532,6 +1542,7 @@ public class JsonReaderTestCase {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Remove update for field 'sparse_tensor' does not contain tensor addresses");
createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "sparse_tensor");
+ createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "mixed_tensor");
}
@Test
diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
index 40ab00facdb..52ed6c63356 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
@@ -10,17 +10,32 @@ import static org.junit.Assert.assertEquals;
public class TensorRemoveUpdateTest {
@Test
- public void apply_remove_update_operations() {
- assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}");
- assertApplyTo("{}", "{{x:0,y:0}:1}", "{}");
- assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}");
+ public void apply_remove_update_operations_sparse() {
+ assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}");
+ assertSparseApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}");
+ assertSparseApplyTo("{}", "{{x:0,y:0}:1}", "{}");
+ assertSparseApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}");
}
- private void assertApplyTo(String init, String update, String expected) {
- String spec = "tensor(x{},y{})";
+ @Test
+ public void apply_remove_update_operations_mixed() {
+ assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0}:1}", "{}");
+ assertMixedApplyTo("{{x:0,y:0}:1, {x:1,y:0}:2}", "{{x:0}:1}", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}");
+ assertMixedApplyTo("{}", "{{x:0}:1}", "{}");
+ assertMixedApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}");
+ }
+
+ private void assertSparseApplyTo(String init, String update, String expected) {
+ assertApplyTo("tensor(x{},y{})", "tensor(x{},y{})", init, update, expected);
+ }
+
+ private void assertMixedApplyTo(String init, String update, String expected) {
+ assertApplyTo("tensor(x{},y[3])", "tensor(x{})", init, update, expected);
+ }
+
+ private void assertApplyTo(String spec, String updateSpec, String init, String update, String expected) {
TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init));
- TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(spec, update)));
+ TensorRemoveUpdate removeUpdate = new TensorRemoveUpdate(new TensorFieldValue(Tensor.from(updateSpec, update)));
TensorFieldValue updatedFieldValue = (TensorFieldValue) removeUpdate.applyTo(initialFieldValue);
assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get());
}
diff --git a/vespa-hadoop/abi-spec.json b/vespa-hadoop/abi-spec.json
index 5bbac15f0e5..e3f4dcf272a 100644
--- a/vespa-hadoop/abi-spec.json
+++ b/vespa-hadoop/abi-spec.json
@@ -1201,6 +1201,8 @@
"public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1245,6 +1247,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)"
@@ -1330,6 +1334,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1432,6 +1438,8 @@
"public double asDouble()",
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 480523982fa..c3fe8c5c7ad 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -809,6 +809,7 @@
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -854,6 +855,7 @@
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)"
@@ -940,6 +942,7 @@
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1043,6 +1046,7 @@
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 704cead7c01..38d832d01c2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -196,6 +196,11 @@ public class IndexedTensor implements Tensor {
}
@Override
+ public Tensor remove(Set<TensorAddress> addresses) {
+ throw new IllegalArgumentException("Remove is not supported for indexed tensors");
+ }
+
+ @Override
public int hashCode() { return Arrays.hashCode(values); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index f44b3ce13b7..22ceed22d3e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.Map;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
@@ -71,6 +72,19 @@ public class MappedTensor implements Tensor {
}
@Override
+ public Tensor remove(Set<TensorAddress> addresses) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ TensorAddress address = cell.getKey();
+ if ( ! addresses.contains(address)) {
+ builder.cell(address, cell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 3630a016691..00229c56171 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -6,9 +6,11 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.stream.Collectors;
@@ -127,6 +129,22 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor remove(Set<TensorAddress> addresses) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) {
+ TensorAddress sparsePartialAddress = entry.getKey();
+ if ( ! addresses.contains(sparsePartialAddress)) {
+ long offset = entry.getValue();
+ for (int i = 0; i < index.denseSubspaceSize; ++i) {
+ Cell cell = cells.get((int)offset + i);
+ builder.cell(cell.getKey(), cell.getValue());
+ }
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 175e6b41daa..a2333f41135 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -25,6 +25,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
@@ -125,7 +126,15 @@ public interface Tensor {
*/
Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells);
-// Tensor remove(Tensor other);
+ /**
+ * Returns a new tensor where existing cells in this tensor have been
+ * removed according to the given set of addresses. Only valid for sparse
+ * or mixed tensors.
+ *
+ * @param addresses list of addresses to remove
+ * @return a new tensor where cells have been removed
+ */
+ Tensor remove(Set<TensorAddress> addresses);
// ----------------- Primitive tensor functions