summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Geir Storli <geirst@verizonmedia.com> 2020-11-17 15:57:17 +0100
committer GitHub <noreply@github.com> 2020-11-17 15:57:17 +0100
commit 2571701816682b8d5989dc49bac7c5441feca213 (patch)
tree b8ecb76cbd5bdbb96f68420e2077d6a0ce398d58
parent 8a4e20a2542e5d9407ca474e2a1e30902bc4158b (diff)
parent cf02c8777d8bff26b2f1cc73e342c38945b7c94c (diff)
Merge pull request #15368 from vespa-engine/geirst/extend-tensor-remove-update
Extend tensor remove update to handle not fully specified addresses
-rw-r--r-- document/abi-spec.json 3
-rw-r--r-- document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java 44
-rw-r--r-- document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java 66
-rw-r--r-- document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java 2
-rw-r--r-- document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java 8
-rw-r--r-- document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java 28
-rw-r--r-- document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java 19
-rw-r--r-- document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java 39
-rw-r--r-- document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java 25
-rw-r--r-- document/src/tests/documentupdatetestcase.cpp 9
-rw-r--r-- document/src/vespa/document/base/testdocrepo.cpp 1
-rw-r--r-- document/src/vespa/document/serialization/vespadocumentdeserializer.cpp 9
-rw-r--r-- document/src/vespa/document/serialization/vespadocumentdeserializer.h 2
-rw-r--r-- document/src/vespa/document/update/tensor_remove_update.cpp 40
-rw-r--r-- eval/src/tests/tensor/partial_remove/partial_remove_test.cpp 24
-rw-r--r-- eval/src/vespa/eval/tensor/partial_update.cpp 78
16 files changed, 326 insertions, 71 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json
index c9191aa2fdb..b119f9991b3 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -3150,9 +3150,11 @@
"public"
],
"methods": [
+ "public void <init>()",
"public void <init>(com.yahoo.tensor.TensorType)",
"public void <init>(com.yahoo.tensor.Tensor)",
"public java.util.Optional getTensor()",
+ "public java.util.Optional getTensorType()",
"public com.yahoo.document.TensorDataType getDataType()",
"public java.lang.String toString()",
"public void printXml(com.yahoo.document.serialization.XmlStream)",
@@ -4379,6 +4381,7 @@
],
"methods": [
"public void <init>(com.yahoo.document.datatypes.TensorFieldValue)",
+ "public void verifyCompatibleType(com.yahoo.tensor.TensorType)",
"protected void checkCompatibility(com.yahoo.document.DataType)",
"public void serialize(com.yahoo.document.serialization.DocumentUpdateWriter, com.yahoo.document.DataType)",
"public com.yahoo.document.datatypes.FieldValue applyTo(com.yahoo.document.datatypes.FieldValue)",
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
index 2c6a556c652..8e7dbd3512a 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
@@ -20,17 +20,27 @@ public class TensorFieldValue extends FieldValue {
private Optional<Tensor> tensor;
- private final TensorDataType dataType;
+ private Optional<TensorDataType> dataType;
+
+ /**
+ * Create an empty tensor field value where the tensor type is not yet known.
+ *
+ * The tensor (and tensor type) can later be assigned with assignTensor().
+ */
+ public TensorFieldValue() {
+ this.dataType = Optional.empty();
+ this.tensor = Optional.empty();
+ }
- /** Create an empty tensor field value */
+ /** Create an empty tensor field value for the given tensor type */
public TensorFieldValue(TensorType type) {
- this.dataType = new TensorDataType(type);
+ this.dataType = Optional.of(new TensorDataType(type));
this.tensor = Optional.empty();
}
/** Create a tensor field value containing the given tensor */
public TensorFieldValue(Tensor tensor) {
- this.dataType = new TensorDataType(tensor.type());
+ this.dataType = Optional.of(new TensorDataType(tensor.type()));
this.tensor = Optional.of(tensor);
}
@@ -38,9 +48,13 @@ public class TensorFieldValue extends FieldValue {
return tensor;
}
+ public Optional<TensorType> getTensorType() {
+ return dataType.isPresent() ? Optional.of(dataType.get().getTensorType()) : Optional.empty();
+ }
+
@Override
public TensorDataType getDataType() {
- return dataType;
+ return dataType.get();
}
@Override
@@ -76,10 +90,22 @@ public class TensorFieldValue extends FieldValue {
}
}
+ /**
+ * Assigns the given tensor to this field value.
+ *
+ * The tensor type is also set from the given tensor if it was not set before.
+ */
public void assignTensor(Optional<Tensor> tensor) {
- if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType()))
- throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
- " to field of type " + dataType.getTensorType());
+ if (tensor.isPresent()) {
+ if (getTensorType().isPresent() &&
+ !tensor.get().type().isAssignableTo(getTensorType().get())) {
+ throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
+ " to field of type " + getTensorType().get());
+ }
+ if (getTensorType().isEmpty()) {
+ this.dataType = Optional.of(new TensorDataType(tensor.get().type()));
+ }
+ }
this.tensor = tensor;
}
@@ -99,7 +125,7 @@ public class TensorFieldValue extends FieldValue {
if ( ! (o instanceof TensorFieldValue)) return false;
TensorFieldValue other = (TensorFieldValue)o;
- if ( ! dataType.getTensorType().equals(other.dataType.getTensorType())) return false;
+ if ( ! getTensorType().equals(other.getTensorType())) return false;
if ( ! tensor.equals(other.tensor)) return false;
return true;
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 91c275b6da0..cffc85777dc 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -1,6 +1,7 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document.json.readers;
+import com.yahoo.collections.Pair;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
@@ -10,6 +11,8 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.util.HashMap;
+
import static com.yahoo.document.json.readers.JsonParserHelpers.expectArrayStart;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectCompositeEnd;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd;
@@ -28,8 +31,8 @@ public class TensorRemoveUpdateReader {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
- TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType);
- Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType);
+ TensorType sparseType = TensorRemoveUpdate.extractSparseDimensions(originalType);
+ Tensor tensor = readRemoveUpdateTensor(buffer, sparseType, originalType);
expectAddressesAreNonEmpty(field, tensor);
return new TensorRemoveUpdate(new TensorFieldValue(tensor));
@@ -54,8 +57,8 @@ public class TensorRemoveUpdateReader {
/**
* Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0
*/
- private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) {
- Tensor.Builder builder = Tensor.Builder.of(type);
+ private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType sparseType, TensorType originalType) {
+ Tensor.Builder builder = null;
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
@@ -63,13 +66,55 @@ public class TensorRemoveUpdateReader {
expectArrayStart(buffer.currentToken());
int nesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) {
- builder.cell(readTensorAddress(buffer, type, originalType), 1.0);
+ if (builder == null) {
+ var typeAndAddress = readFirstTensorAddress(buffer, sparseType, originalType);
+ builder = Tensor.Builder.of(typeAndAddress.getFirst());
+ builder.cell(typeAndAddress.getSecond(), 1.0);
+ } else {
+ builder.cell(readTensorAddress(buffer, builder.type(), originalType), 1.0);
+ }
}
expectCompositeEnd(buffer.currentToken());
}
}
expectObjectEnd(buffer.currentToken());
- return builder.build();
+ return (builder != null) ? builder.build() : Tensor.Builder.of(sparseType).build();
+ }
+
+ /**
+ * Reads the first raw tensor address from the given buffer, then resolves and returns the tensor type and tensor address derived from it.
+ * The resulting tensor type contains a subset or all of the dimensions from the given sparseType.
+ */
+ private static Pair<TensorType, TensorAddress> readFirstTensorAddress(TokenBuffer buffer, TensorType sparseType, TensorType originalType) {
+ var typeBuilder = new TensorType.Builder(sparseType.valueType());
+ var rawAddress = new HashMap<String, String>();
+ expectObjectStart(buffer.currentToken());
+ int initNesting = buffer.nesting();
+ for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
+ var elem = readRawElement(buffer, sparseType, originalType);
+ var dimension = sparseType.dimension(elem.getFirst());
+ if (dimension.isPresent()) {
+ typeBuilder.dimension(dimension.get());
+ rawAddress.put(elem.getFirst(), elem.getSecond());
+ } else {
+ throw new IllegalArgumentException(originalType + " does not contain dimension '" + elem.getFirst() + "'");
+ }
+ }
+ expectObjectEnd(buffer.currentToken());
+ var type = typeBuilder.build();
+ var builder = new TensorAddress.Builder(type);
+ rawAddress.forEach((dimension, label) -> builder.add(dimension, label));
+ return new Pair<>(type, builder.build());
+ }
+
+ private static Pair<String, String> readRawElement(TokenBuffer buffer, TensorType type, TensorType originalType) {
+ String dimension = buffer.currentName();
+ if (type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) {
+ throw new IllegalArgumentException("Indexed dimension address '" + dimension +
+ "' should not be specified in remove update");
+ }
+ String label = buffer.currentText();
+ return new Pair<>(dimension, label);
}
private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) {
@@ -77,13 +122,8 @@ public class TensorRemoveUpdateReader {
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
- String dimension = buffer.currentName();
- if ( type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) {
- throw new IllegalArgumentException("Indexed dimension address '" + dimension +
- "' should not be specified in remove update");
- }
- String label = buffer.currentText();
- builder.add(dimension, label);
+ var elem = readRawElement(buffer, type, originalType);
+ builder.add(elem.getFirst(), elem.getSecond());
}
expectObjectEnd(buffer.currentToken());
return builder.build();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java
index cac05fb7879..92b3b566b85 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java
@@ -243,7 +243,7 @@ public class VespaDocumentDeserializer6 extends BufferSerializer implements Docu
int encodedTensorLength = buf.getInt1_4Bytes();
if (encodedTensorLength > 0) {
byte[] encodedTensor = getBytes(null, encodedTensorLength);
- value.assign(TypedBinaryFormat.decode(Optional.of(value.getDataType().getTensorType()),
+ value.assign(TypedBinaryFormat.decode(value.getTensorType(),
GrowableByteBuffer.wrap(encodedTensor)));
} else {
value.clear();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index 58c50f047f9..e7f1525ff81 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -61,10 +61,12 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
}
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
- TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType);
- TensorFieldValue tensor = new TensorFieldValue(convertedType);
+ TensorFieldValue tensor = new TensorFieldValue();
tensor.deserialize(this);
- return new TensorRemoveUpdate(tensor);
+ var result = new TensorRemoveUpdate(tensor);
+ result.verifyCompatibleType(tensorType);
+ return result;
}
+
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index 981120af145..a300565391f 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -12,9 +12,10 @@ import com.yahoo.tensor.TensorType;
import java.util.Objects;
/**
- * An update used to remove cells from a sparse tensor (has only mapped dimensions).
+ * An update used to remove cells from a sparse tensor or dense sub-spaces from a mixed tensor.
*
- * The cells to remove are contained in a sparse tensor where cell values are set to 1.0
+ * The specification of which cells to remove contains addresses using a subset or all of the sparse dimensions of the tensor type.
+ * This is represented as a sparse tensor where cell values are set to 1.0.
*/
public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
@@ -23,17 +24,20 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
public TensorRemoveUpdate(TensorFieldValue value) {
super(ValueUpdateClassID.TENSORREMOVE);
this.tensor = value;
- verifyCompatibleType();
- }
-
- private void verifyCompatibleType() {
- if ( ! tensor.getTensor().isPresent()) {
+ if (!tensor.getTensor().isPresent()) {
throw new IllegalArgumentException("Tensor must be present in remove update");
}
- TensorType tensorType = tensor.getTensor().get().type();
- TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType());
- if ( ! tensorType.equals(expectedType)) {
- throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'");
+ verifyCompatibleType(tensor.getTensorType().get());
+ }
+
+ public void verifyCompatibleType(TensorType originalType) {
+ TensorType sparseType = extractSparseDimensions(originalType);
+ TensorType thisType = tensor.getTensorType().get();
+ for (var dim : thisType.dimensions()) {
+ if (sparseType.dimension(dim.name()).isEmpty()) {
+ throw new IllegalArgumentException("Unexpected type '" + thisType + "' in remove update. "
+ + "Expected dimensions to be a subset of '" + sparseType + "'");
+ }
}
}
@@ -63,6 +67,7 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
Tensor update = tensor.getTensor().get();
+ // TODO: handle the case where this tensor only contains a subset of the sparse dimensions of the input tensor.
Tensor result = old.remove(update.cells().keySet());
return new TensorFieldValue(result);
}
@@ -102,5 +107,4 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return builder.build();
}
-
}
diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
index 4e8fa427e7d..1772a410a36 100644
--- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
+++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
@@ -436,6 +436,25 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_remove_update_with_not_fully_specified_address() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'sparse_tensor': {",
+ " 'remove': {",
+ " 'addresses': [",
+ " {'y':'0'},",
+ " {'y':'2'}",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void reference_field_id_can_be_update_assigned_non_empty_id() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 7fc43656d55..da9ab4ea7bf 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -168,6 +168,8 @@ public class JsonReaderTestCase {
new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build())));
x.addField(new Field("mixed_tensor",
new TensorDataType(new TensorType.Builder().mapped("x").indexed("y", 3).build())));
+ x.addField(new Field("mixed_tensor_adv",
+ new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").mapped("z").indexed("a", 3).build())));
types.registerDocumentType(x);
}
{
@@ -1685,6 +1687,24 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_remove_update_on_sparse_tensor_with_not_fully_specified_address() {
+ assertTensorRemoveUpdate("{{y:b}:1.0,{y:d}:1.0}", "sparse_tensor",
+ inputJson("{",
+ " 'addresses': [",
+ " { 'y': 'b' },",
+ " { 'y': 'd' } ]}"));
+ }
+
+ @Test
+ public void tensor_remove_update_on_mixed_tensor_with_not_fully_specified_address() {
+ assertTensorRemoveUpdate("{{x:1,z:a}:1.0,{x:2,z:b}:1.0}", "mixed_tensor_adv",
+ inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1', 'z': 'a' },",
+ " { 'x': '2', 'z': 'b' } ]}"));
+ }
+
+ @Test
public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() {
illegalTensorRemoveUpdate("Error in 'mixed_tensor': Indexed dimension address 'y' should not be specified in remove update",
"mixed_tensor",
@@ -1703,12 +1723,19 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_remove_update_on_not_fully_specified_cell_throws() {
- illegalTensorRemoveUpdate("Error in 'sparse_tensor': Missing a label for dimension y for tensor(x{},y{})",
- "sparse_tensor",
- "{",
- " 'addresses': [",
- " { 'x': 'a' } ]}");
+ public void tensor_remove_update_with_stray_dimension_throws() {
+ illegalTensorRemoveUpdate("Error in 'sparse_tensor': tensor(x{},y{}) does not contain dimension 'foo'",
+ "sparse_tensor",
+ "{",
+ " 'addresses': [",
+ " { 'x': 'a', 'foo': 'b' } ]}");
+
+ illegalTensorRemoveUpdate("Error in 'sparse_tensor': tensor(x{}) does not contain dimension 'foo'",
+ "sparse_tensor",
+ "{",
+ " 'addresses': [",
+ " { 'x': 'c' },",
+ " { 'x': 'a', 'foo': 'b' } ]}");
}
@Test
diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
index 3a005e858c8..86f07db1b2d 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
@@ -3,9 +3,12 @@ package com.yahoo.document.update;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.yolean.Exceptions;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
public class TensorRemoveUpdateTest {
@@ -22,4 +25,26 @@ public class TensorRemoveUpdateTest {
assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get());
}
+ @Test
+ public void verify_compatible_type_throws_on_mismatch() {
+ // Contains an indexed dimension, which is not allowed.
+ illegalTensorRemoveUpdate("tensor(x{},y[1])", "{{x:a,y:0}:1}", "tensor(x{},y[1])",
+ "Unexpected type 'tensor(x{},y[1])' in remove update. Expected dimensions to be a subset of 'tensor(x{})'");
+
+ // Sparse dimension is not found in the original type.
+ illegalTensorRemoveUpdate("tensor(y{})", "{{y:a}:1}", "tensor(x{},z[2])",
+ "Unexpected type 'tensor(y{})' in remove update. Expected dimensions to be a subset of 'tensor(x{})'");
+ }
+
+ private void illegalTensorRemoveUpdate(String updateType, String updateTensor, String originalType, String expectedMessage) {
+ try {
+ var value = new TensorFieldValue(Tensor.from(updateType, updateTensor));
+ new TensorRemoveUpdate(value).verifyCompatibleType(TensorType.fromSpec(originalType));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals(expectedMessage, Exceptions.toMessageString(expected));
+ }
+ }
+
}
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 5fd62957f65..b88a0437fc2 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -1031,6 +1031,13 @@ TEST(DocumentUpdateTest, tensor_remove_update_can_be_roundtrip_serialized)
f.assertRoundtripSerialize(TensorRemoveUpdate(f.makeBaselineTensor()));
}
+TEST(DocumentUpdateTest, tensor_remove_update_with_not_fully_specified_address_can_be_roundtrip_serialized)
+{
+ TensorUpdateFixture f("sparse_xy_tensor");
+ TensorDataType type(ValueType::from_spec("tensor(y{})"));
+ f.assertRoundtripSerialize(TensorRemoveUpdate(
+ makeTensorFieldValue(TensorSpec("tensor(y{})").add({{"y", "a"}}, 1), type)));
+}
TEST(DocumentUpdateTest, tensor_remove_update_on_float_tensor_can_be_roundtrip_serialized)
{
@@ -1087,7 +1094,7 @@ TEST(DocumentUpdateTest, tensor_remove_update_throws_if_address_tensor_is_not_sp
auto addressTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense address tensor
ASSERT_THROW(
f.assertRoundtripSerialize(TensorRemoveUpdate(std::move(addressTensor))),
- document::WrongTensorTypeException);
+ vespalib::IllegalStateException);
}
TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_sparse)
diff --git a/document/src/vespa/document/base/testdocrepo.cpp b/document/src/vespa/document/base/testdocrepo.cpp
index 58d5a30ec35..24625c6f667 100644
--- a/document/src/vespa/document/base/testdocrepo.cpp
+++ b/document/src/vespa/document/base/testdocrepo.cpp
@@ -53,6 +53,7 @@ DocumenttypesConfig TestDocRepo::getDefaultConfig() {
.addField("rawarray", Array(DataType::T_RAW))
.addField("structarray", structarray_id)
.addTensorField("sparse_tensor", "tensor(x{})")
+ .addTensorField("sparse_xy_tensor", "tensor(x{},y{})")
.addTensorField("sparse_float_tensor", "tensor<float>(x{})")
.addTensorField("dense_tensor", "tensor(x[2])"));
builder.document(type2_id, "testdoctype2",
diff --git a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp
index 6ec9c52281f..eaa5a484ad1 100644
--- a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp
+++ b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp
@@ -355,10 +355,15 @@ void VespaDocumentDeserializer::read(WeightedSetFieldValue &value) {
}
}
-
void
VespaDocumentDeserializer::read(TensorFieldValue &value)
{
+ value.assignDeserialized(readTensor());
+}
+
+std::unique_ptr<vespalib::eval::Value>
+VespaDocumentDeserializer::readTensor()
+{
size_t length = _stream.getInt1_4Bytes();
if (length > _stream.size()) {
throw DeserializeException(vespalib::make_string("Stream failed size(%zu), needed(%zu) to deserialize tensor field value", _stream.size(), length),
@@ -372,8 +377,8 @@ VespaDocumentDeserializer::read(TensorFieldValue &value)
throw DeserializeException("Leftover bytes deserializing tensor field value.", VESPA_STRLOC);
}
}
- value.assignDeserialized(std::move(tensor));
_stream.adjustReadPos(length);
+ return tensor;
}
void VespaDocumentDeserializer::read(ReferenceFieldValue& value) {
diff --git a/document/src/vespa/document/serialization/vespadocumentdeserializer.h b/document/src/vespa/document/serialization/vespadocumentdeserializer.h
index e6b490e1075..6792914d9da 100644
--- a/document/src/vespa/document/serialization/vespadocumentdeserializer.h
+++ b/document/src/vespa/document/serialization/vespadocumentdeserializer.h
@@ -7,6 +7,7 @@
#include <memory>
namespace vespalib { class nbostream; }
+namespace vespalib::eval { class Value; }
namespace document {
class DocumentId;
@@ -78,6 +79,7 @@ public:
void readStructNoReset(StructFieldValue &value);
void read(WeightedSetFieldValue &value);
void read(TensorFieldValue &value);
+ std::unique_ptr<vespalib::eval::Value> readTensor();
void read(ReferenceFieldValue& value);
};
} // namespace document
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index 5d85b8956fa..688f9cf5399 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -20,6 +20,7 @@
using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
using vespalib::make_string;
+using vespalib::eval::Value;
using vespalib::eval::ValueType;
using vespalib::eval::EngineOrFactory;
using vespalib::tensor::TensorPartialUpdate;
@@ -157,38 +158,47 @@ TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &in
namespace {
void
-verifyAddressTensorIsSparse(const vespalib::eval::Value *addressTensor)
+verifyAddressTensorIsSparse(const Value *addressTensor)
{
if (addressTensor == nullptr) {
- return;
+ throw IllegalStateException("Address tensor is not set", VESPA_STRLOC);
}
auto engine = EngineOrFactory::get();
if (TensorPartialUpdate::check_suitably_sparse(*addressTensor, engine)) {
return;
}
- vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'",
- addressTensor->type().to_spec().c_str());
+ auto err = make_string("Expected address tensor to be sparse, but has type '%s'",
+ addressTensor->type().to_spec().c_str());
throw IllegalStateException(err, VESPA_STRLOC);
}
+void
+verify_tensor_type_dimensions_are_subset_of(const ValueType& lhs_type,
+ const ValueType& rhs_type)
+{
+ for (const auto& dim : lhs_type.dimensions()) {
+ if (rhs_type.dimension_index(dim.name) == ValueType::Dimension::npos) {
+ auto err = make_string("Unexpected type '%s' for address tensor. "
+ "Expected dimensions to be a subset of '%s'",
+ lhs_type.to_spec().c_str(), rhs_type.to_spec().c_str());
+ throw IllegalStateException(err, VESPA_STRLOC);
+ }
+ }
+}
}
void
TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream)
{
- _tensorType = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type));
- auto tensor = _tensorType->createFieldValue();
- if (tensor->inherits(TensorFieldValue::classId)) {
- _tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
- } else {
- vespalib::string err = make_string("Expected tensor field value, got a '%s' field value",
- tensor->getClass().name());
- throw IllegalStateException(err, VESPA_STRLOC);
- }
VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion());
- deserializer.read(*_tensor);
- verifyAddressTensorIsSparse(_tensor->getAsTensorPtr());
+ auto tensor = deserializer.readTensor();
+ verifyAddressTensorIsSparse(tensor.get());
+ auto compatible_type = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type));
+ verify_tensor_type_dimensions_are_subset_of(tensor->type(), compatible_type->getTensorType());
+ _tensorType = std::make_unique<const TensorDataType>(tensor->type());
+ _tensor = std::make_unique<TensorFieldValue>(*_tensorType);
+ _tensor->assignDeserialized(std::move(tensor));
}
TensorRemoveUpdate *
diff --git a/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp b/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp
index 220eee0ba8f..e182fffa890 100644
--- a/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp
+++ b/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp
@@ -116,4 +116,28 @@ TEST(PartialRemoveTest, partial_remove_returns_nullptr_on_invalid_inputs) {
}
}
+void
+expect_partial_remove(const TensorSpec& input, const TensorSpec& remove, const TensorSpec& exp)
+{
+ auto act = perform_partial_remove(input, remove);
+ EXPECT_EQ(exp, act);
+}
+
+TEST(PartialRemoveTest, remove_where_address_is_not_fully_specified) {
+ auto input = TensorSpec("tensor(x{},y{})").
+ add({{"x", "a"},{"y", "c"}}, 3.0).
+ add({{"x", "a"},{"y", "d"}}, 5.0).
+ add({{"x", "b"},{"y", "c"}}, 7.0);
+
+ expect_partial_remove(input,TensorSpec("tensor(x{})").add({{"x", "a"}}, 1.0),
+ TensorSpec("tensor(x{},y{})").add({{"x", "b"},{"y", "c"}}, 7.0));
+
+ expect_partial_remove(input, TensorSpec("tensor(y{})").add({{"y", "c"}}, 1.0),
+ TensorSpec("tensor(x{},y{})").add({{"x", "a"},{"y", "d"}}, 5.0));
+
+ expect_partial_remove(input, TensorSpec("tensor(y{})").add({{"y", "d"}}, 1.0),
+ TensorSpec("tensor(x{},y{})").add({{"x", "a"},{"y", "c"}}, 3.0)
+ .add({{"x", "b"},{"y", "c"}}, 7.0));
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/tensor/partial_update.cpp b/eval/src/vespa/eval/tensor/partial_update.cpp
index 014ffeb2666..fa15b2a38ae 100644
--- a/eval/src/vespa/eval/tensor/partial_update.cpp
+++ b/eval/src/vespa/eval/tensor/partial_update.cpp
@@ -298,31 +298,91 @@ struct PerformRemove {
const ValueBuilderFactory &factory);
};
+/**
+ * Calculates the indexes of where the mapped modifier dimensions are found in the mapped input dimensions.
+ *
+ * The modifier dimensions should be a subset or all of the input dimensions.
+ * An empty vector is returned on type mismatch.
+ */
+std::vector<size_t>
+calc_mapped_dimension_indexes(const ValueType& input_type,
+ const ValueType& modifier_type)
+{
+ auto input_dims = input_type.mapped_dimensions();
+ auto mod_dims = modifier_type.mapped_dimensions();
+ if (mod_dims.size() > input_dims.size()) {
+ return {};
+ }
+ std::vector<size_t> result(mod_dims.size());
+ size_t j = 0;
+ for (size_t i = 0; i < mod_dims.size(); ++i) {
+ while ((j < input_dims.size()) && (input_dims[j] != mod_dims[i])) {
+ ++j;
+ }
+ if (j >= input_dims.size()) {
+ return {};
+ }
+ result[i] = j;
+ }
+ return result;
+}
+
+struct ModifierCoords {
+
+ std::vector<const vespalib::stringref *> lookup_refs;
+ std::vector<size_t> lookup_view_dims;
+
+ ModifierCoords(const SparseCoords& input_coords,
+ const std::vector<size_t>& input_dim_indexes,
+ const ValueType& modifier_type)
+ : lookup_refs(modifier_type.dimensions().size()),
+ lookup_view_dims(modifier_type.dimensions().size())
+ {
+ assert(modifier_type.dimensions().size() == input_dim_indexes.size());
+ for (size_t i = 0; i < input_dim_indexes.size(); ++i) {
+ // Set up the modifier dimensions to point to the matching input dimensions.
+ lookup_refs[i] = &input_coords.addr[input_dim_indexes[i]];
+ lookup_view_dims[i] = i;
+ }
+ }
+ ~ModifierCoords() {}
+};
+
template <typename ICT>
Value::UP
PerformRemove::invoke(const Value &input, const Value &modifier, const ValueBuilderFactory &factory)
{
const ValueType &input_type = input.type();
const ValueType &modifier_type = modifier.type();
- if (input_type.mapped_dimensions() != modifier_type.dimensions()) {
- LOG(error, "when removing cells from a tensor, mapped dimensions must be equal. "
- "Got input type %s versus modifier type %s",
- input_type.to_spec().c_str(), modifier_type.to_spec().c_str());
- return {};
- }
const size_t num_mapped_in_input = input_type.count_mapped_dimensions();
if (num_mapped_in_input == 0) {
- LOG(error, "cannot remove cells from a dense tensor of type %s",
+ LOG(error, "Cannot remove cells from a dense input tensor of type %s",
input_type.to_spec().c_str());
return {};
}
+ if (modifier_type.count_indexed_dimensions() != 0) {
+ LOG(error, "Cannot remove cells using a modifier tensor of type %s",
+ modifier_type.to_spec().c_str());
+ return {};
+ }
+ auto input_dim_indexes = calc_mapped_dimension_indexes(input_type, modifier_type);
+ if (input_dim_indexes.empty()) {
+ LOG(error, "Tensor type mismatch when removing cells from a tensor. "
+ "Got input type %s versus modifier type %s",
+ input_type.to_spec().c_str(), modifier_type.to_spec().c_str());
+ return {};
+ }
SparseCoords addrs(num_mapped_in_input);
- auto modifier_view = modifier.index().create_view(addrs.lookup_view_dims);
+ ModifierCoords mod_coords(addrs, input_dim_indexes, modifier_type);
+ auto modifier_view = modifier.index().create_view(mod_coords.lookup_view_dims);
const size_t expected_subspaces = input.index().size();
const size_t dsss = input_type.dense_subspace_size();
auto builder = factory.create_value_builder<ICT>(input_type, num_mapped_in_input, dsss, expected_subspaces);
auto filter_by_modifier = [&] (const auto & lookup_refs, size_t) {
- modifier_view->lookup(lookup_refs);
+ // The modifier dimensions are set up to point to the input dimensions' address storage in ModifierCoords,
+ // so we don't need to use the lookup_refs argument.
+ (void) lookup_refs;
+ modifier_view->lookup(mod_coords.lookup_refs);
size_t modifier_subspace_index;
return !(modifier_view->next_result({}, modifier_subspace_index));
};