summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Geir Storli <geirst@verizonmedia.com> 2020-11-11 17:06:23 +0000
committer Geir Storli <geirst@verizonmedia.com> 2020-11-17 12:57:49 +0000
commit 4cd2c6a1d4d2ab7337678931271a815b535ce518 (patch)
tree ede014280efb581bc841c7ab7702c0a7b9e028a6
parent e4c14623ad4ecbe6337a49d2176621c528bf7c22 (diff)
Extend tensor remove update to support not fully specified addresses and update JSON parser.
Previously, all the sparse dimensions of the sparse or mixed tensor type (to remove from) had to be specified in the addresses to remove.
-rw-r--r-- document/abi-spec.json 3
-rw-r--r-- document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java 44
-rw-r--r-- document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java 66
-rw-r--r-- document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java 2
-rw-r--r-- document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java 8
-rw-r--r-- document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java 28
-rw-r--r-- document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java 19
-rw-r--r-- document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java 39
-rw-r--r-- document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java 25
-rw-r--r-- document/src/tests/documentupdatetestcase.cpp 9
-rw-r--r-- document/src/vespa/document/base/testdocrepo.cpp 1
-rw-r--r-- document/src/vespa/document/serialization/vespadocumentdeserializer.cpp 9
-rw-r--r-- document/src/vespa/document/serialization/vespadocumentdeserializer.h 2
-rw-r--r-- document/src/vespa/document/update/tensor_remove_update.cpp 40
14 files changed, 233 insertions, 62 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json
index c9191aa2fdb..b119f9991b3 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -3150,9 +3150,11 @@
"public"
],
"methods": [
+ "public void <init>()",
"public void <init>(com.yahoo.tensor.TensorType)",
"public void <init>(com.yahoo.tensor.Tensor)",
"public java.util.Optional getTensor()",
+ "public java.util.Optional getTensorType()",
"public com.yahoo.document.TensorDataType getDataType()",
"public java.lang.String toString()",
"public void printXml(com.yahoo.document.serialization.XmlStream)",
@@ -4379,6 +4381,7 @@
],
"methods": [
"public void <init>(com.yahoo.document.datatypes.TensorFieldValue)",
+ "public void verifyCompatibleType(com.yahoo.tensor.TensorType)",
"protected void checkCompatibility(com.yahoo.document.DataType)",
"public void serialize(com.yahoo.document.serialization.DocumentUpdateWriter, com.yahoo.document.DataType)",
"public com.yahoo.document.datatypes.FieldValue applyTo(com.yahoo.document.datatypes.FieldValue)",
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
index 2c6a556c652..8e7dbd3512a 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
@@ -20,17 +20,27 @@ public class TensorFieldValue extends FieldValue {
private Optional<Tensor> tensor;
- private final TensorDataType dataType;
+ private Optional<TensorDataType> dataType;
+
+ /**
+ * Create an empty tensor field value where the tensor type is not yet known.
+ *
+ * The tensor (and tensor type) can later be assigned with assignTensor().
+ */
+ public TensorFieldValue() {
+ this.dataType = Optional.empty();
+ this.tensor = Optional.empty();
+ }
- /** Create an empty tensor field value */
+ /** Create an empty tensor field value for the given tensor type */
public TensorFieldValue(TensorType type) {
- this.dataType = new TensorDataType(type);
+ this.dataType = Optional.of(new TensorDataType(type));
this.tensor = Optional.empty();
}
/** Create a tensor field value containing the given tensor */
public TensorFieldValue(Tensor tensor) {
- this.dataType = new TensorDataType(tensor.type());
+ this.dataType = Optional.of(new TensorDataType(tensor.type()));
this.tensor = Optional.of(tensor);
}
@@ -38,9 +48,13 @@ public class TensorFieldValue extends FieldValue {
return tensor;
}
+ public Optional<TensorType> getTensorType() {
+ return dataType.isPresent() ? Optional.of(dataType.get().getTensorType()) : Optional.empty();
+ }
+
@Override
public TensorDataType getDataType() {
- return dataType;
+ return dataType.get();
}
@Override
@@ -76,10 +90,22 @@ public class TensorFieldValue extends FieldValue {
}
}
+ /**
+ * Assigns the given tensor to this field value.
+ *
+ * The tensor type is also set from the given tensor if it was not set before.
+ */
public void assignTensor(Optional<Tensor> tensor) {
- if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType()))
- throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
- " to field of type " + dataType.getTensorType());
+ if (tensor.isPresent()) {
+ if (getTensorType().isPresent() &&
+ !tensor.get().type().isAssignableTo(getTensorType().get())) {
+ throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
+ " to field of type " + getTensorType().get());
+ }
+ if (getTensorType().isEmpty()) {
+ this.dataType = Optional.of(new TensorDataType(tensor.get().type()));
+ }
+ }
this.tensor = tensor;
}
@@ -99,7 +125,7 @@ public class TensorFieldValue extends FieldValue {
if ( ! (o instanceof TensorFieldValue)) return false;
TensorFieldValue other = (TensorFieldValue)o;
- if ( ! dataType.getTensorType().equals(other.dataType.getTensorType())) return false;
+ if ( ! getTensorType().equals(other.getTensorType())) return false;
if ( ! tensor.equals(other.tensor)) return false;
return true;
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 91c275b6da0..cffc85777dc 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -1,6 +1,7 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document.json.readers;
+import com.yahoo.collections.Pair;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
@@ -10,6 +11,8 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.util.HashMap;
+
import static com.yahoo.document.json.readers.JsonParserHelpers.expectArrayStart;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectCompositeEnd;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd;
@@ -28,8 +31,8 @@ public class TensorRemoveUpdateReader {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
- TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType);
- Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType);
+ TensorType sparseType = TensorRemoveUpdate.extractSparseDimensions(originalType);
+ Tensor tensor = readRemoveUpdateTensor(buffer, sparseType, originalType);
expectAddressesAreNonEmpty(field, tensor);
return new TensorRemoveUpdate(new TensorFieldValue(tensor));
@@ -54,8 +57,8 @@ public class TensorRemoveUpdateReader {
/**
* Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0
*/
- private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) {
- Tensor.Builder builder = Tensor.Builder.of(type);
+ private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType sparseType, TensorType originalType) {
+ Tensor.Builder builder = null;
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
@@ -63,13 +66,55 @@ public class TensorRemoveUpdateReader {
expectArrayStart(buffer.currentToken());
int nesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) {
- builder.cell(readTensorAddress(buffer, type, originalType), 1.0);
+ if (builder == null) {
+ var typeAndAddress = readFirstTensorAddress(buffer, sparseType, originalType);
+ builder = Tensor.Builder.of(typeAndAddress.getFirst());
+ builder.cell(typeAndAddress.getSecond(), 1.0);
+ } else {
+ builder.cell(readTensorAddress(buffer, builder.type(), originalType), 1.0);
+ }
}
expectCompositeEnd(buffer.currentToken());
}
}
expectObjectEnd(buffer.currentToken());
- return builder.build();
+ return (builder != null) ? builder.build() : Tensor.Builder.of(sparseType).build();
+ }
+
+ /**
+ * Reads the first raw tensor address from the given buffer and resolves and returns the tensor type and tensor address based on this.
+ * The resulting tensor type contains a subset or all of the dimensions from the given sparseType.
+ */
+ private static Pair<TensorType, TensorAddress> readFirstTensorAddress(TokenBuffer buffer, TensorType sparseType, TensorType originalType) {
+ var typeBuilder = new TensorType.Builder(sparseType.valueType());
+ var rawAddress = new HashMap<String, String>();
+ expectObjectStart(buffer.currentToken());
+ int initNesting = buffer.nesting();
+ for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
+ var elem = readRawElement(buffer, sparseType, originalType);
+ var dimension = sparseType.dimension(elem.getFirst());
+ if (dimension.isPresent()) {
+ typeBuilder.dimension(dimension.get());
+ rawAddress.put(elem.getFirst(), elem.getSecond());
+ } else {
+ throw new IllegalArgumentException(originalType + " does not contain dimension '" + elem.getFirst() + "'");
+ }
+ }
+ expectObjectEnd(buffer.currentToken());
+ var type = typeBuilder.build();
+ var builder = new TensorAddress.Builder(type);
+ rawAddress.forEach((dimension, label) -> builder.add(dimension, label));
+ return new Pair<>(type, builder.build());
+ }
+
+ private static Pair<String, String> readRawElement(TokenBuffer buffer, TensorType type, TensorType originalType) {
+ String dimension = buffer.currentName();
+ if (type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) {
+ throw new IllegalArgumentException("Indexed dimension address '" + dimension +
+ "' should not be specified in remove update");
+ }
+ String label = buffer.currentText();
+ return new Pair<>(dimension, label);
}
private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) {
@@ -77,13 +122,8 @@ public class TensorRemoveUpdateReader {
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
- String dimension = buffer.currentName();
- if ( type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) {
- throw new IllegalArgumentException("Indexed dimension address '" + dimension +
- "' should not be specified in remove update");
- }
- String label = buffer.currentText();
- builder.add(dimension, label);
+ var elem = readRawElement(buffer, type, originalType);
+ builder.add(elem.getFirst(), elem.getSecond());
}
expectObjectEnd(buffer.currentToken());
return builder.build();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java
index cac05fb7879..92b3b566b85 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java
@@ -243,7 +243,7 @@ public class VespaDocumentDeserializer6 extends BufferSerializer implements Docu
int encodedTensorLength = buf.getInt1_4Bytes();
if (encodedTensorLength > 0) {
byte[] encodedTensor = getBytes(null, encodedTensorLength);
- value.assign(TypedBinaryFormat.decode(Optional.of(value.getDataType().getTensorType()),
+ value.assign(TypedBinaryFormat.decode(value.getTensorType(),
GrowableByteBuffer.wrap(encodedTensor)));
} else {
value.clear();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index 58c50f047f9..e7f1525ff81 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -61,10 +61,12 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
}
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
- TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType);
- TensorFieldValue tensor = new TensorFieldValue(convertedType);
+ TensorFieldValue tensor = new TensorFieldValue();
tensor.deserialize(this);
- return new TensorRemoveUpdate(tensor);
+ var result = new TensorRemoveUpdate(tensor);
+ result.verifyCompatibleType(tensorType);
+ return result;
}
+
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index 981120af145..a300565391f 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -12,9 +12,10 @@ import com.yahoo.tensor.TensorType;
import java.util.Objects;
/**
- * An update used to remove cells from a sparse tensor (has only mapped dimensions).
+ * An update used to remove cells from a sparse tensor or dense sub-spaces from a mixed tensor.
*
- * The cells to remove are contained in a sparse tensor where cell values are set to 1.0
+ * The specification of which cells to remove contains addresses using a subset or all of the sparse dimensions of the tensor type.
+ * This is represented as a sparse tensor where cell values are set to 1.0.
*/
public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
@@ -23,17 +24,20 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
public TensorRemoveUpdate(TensorFieldValue value) {
super(ValueUpdateClassID.TENSORREMOVE);
this.tensor = value;
- verifyCompatibleType();
- }
-
- private void verifyCompatibleType() {
- if ( ! tensor.getTensor().isPresent()) {
+ if (!tensor.getTensor().isPresent()) {
throw new IllegalArgumentException("Tensor must be present in remove update");
}
- TensorType tensorType = tensor.getTensor().get().type();
- TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType());
- if ( ! tensorType.equals(expectedType)) {
- throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'");
+ verifyCompatibleType(tensor.getTensorType().get());
+ }
+
+ public void verifyCompatibleType(TensorType originalType) {
+ TensorType sparseType = extractSparseDimensions(originalType);
+ TensorType thisType = tensor.getTensorType().get();
+ for (var dim : thisType.dimensions()) {
+ if (sparseType.dimension(dim.name()).isEmpty()) {
+ throw new IllegalArgumentException("Unexpected type '" + thisType + "' in remove update. "
+ + "Expected dimensions to be a subset of '" + sparseType + "'");
+ }
}
}
@@ -63,6 +67,7 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
Tensor update = tensor.getTensor().get();
+ // TODO: handle the case where this tensor only contains a subset of the sparse dimensions of the input tensor.
Tensor result = old.remove(update.cells().keySet());
return new TensorFieldValue(result);
}
@@ -102,5 +107,4 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return builder.build();
}
-
}
diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
index 4e8fa427e7d..1772a410a36 100644
--- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
+++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
@@ -436,6 +436,25 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_remove_update_with_not_fully_specified_address() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'sparse_tensor': {",
+ " 'remove': {",
+ " 'addresses': [",
+ " {'y':'0'},",
+ " {'y':'2'}",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void reference_field_id_can_be_update_assigned_non_empty_id() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 7fc43656d55..da9ab4ea7bf 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -168,6 +168,8 @@ public class JsonReaderTestCase {
new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build())));
x.addField(new Field("mixed_tensor",
new TensorDataType(new TensorType.Builder().mapped("x").indexed("y", 3).build())));
+ x.addField(new Field("mixed_tensor_adv",
+ new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").mapped("z").indexed("a", 3).build())));
types.registerDocumentType(x);
}
{
@@ -1685,6 +1687,24 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_remove_update_on_sparse_tensor_with_not_fully_specified_address() {
+ assertTensorRemoveUpdate("{{y:b}:1.0,{y:d}:1.0}", "sparse_tensor",
+ inputJson("{",
+ " 'addresses': [",
+ " { 'y': 'b' },",
+ " { 'y': 'd' } ]}"));
+ }
+
+ @Test
+ public void tensor_remove_update_on_mixed_tensor_with_not_fully_specified_address() {
+ assertTensorRemoveUpdate("{{x:1,z:a}:1.0,{x:2,z:b}:1.0}", "mixed_tensor_adv",
+ inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1', 'z': 'a' },",
+ " { 'x': '2', 'z': 'b' } ]}"));
+ }
+
+ @Test
public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() {
illegalTensorRemoveUpdate("Error in 'mixed_tensor': Indexed dimension address 'y' should not be specified in remove update",
"mixed_tensor",
@@ -1703,12 +1723,19 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_remove_update_on_not_fully_specified_cell_throws() {
- illegalTensorRemoveUpdate("Error in 'sparse_tensor': Missing a label for dimension y for tensor(x{},y{})",
- "sparse_tensor",
- "{",
- " 'addresses': [",
- " { 'x': 'a' } ]}");
+ public void tensor_remove_update_with_stray_dimension_throws() {
+ illegalTensorRemoveUpdate("Error in 'sparse_tensor': tensor(x{},y{}) does not contain dimension 'foo'",
+ "sparse_tensor",
+ "{",
+ " 'addresses': [",
+ " { 'x': 'a', 'foo': 'b' } ]}");
+
+ illegalTensorRemoveUpdate("Error in 'sparse_tensor': tensor(x{}) does not contain dimension 'foo'",
+ "sparse_tensor",
+ "{",
+ " 'addresses': [",
+ " { 'x': 'c' },",
+ " { 'x': 'a', 'foo': 'b' } ]}");
}
@Test
diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
index 3a005e858c8..86f07db1b2d 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
@@ -3,9 +3,12 @@ package com.yahoo.document.update;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.yolean.Exceptions;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
public class TensorRemoveUpdateTest {
@@ -22,4 +25,26 @@ public class TensorRemoveUpdateTest {
assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get());
}
+ @Test
+ public void verify_compatible_type_throws_on_mismatch() {
+ // Contains an indexed dimension, which is not allowed.
+ illegalTensorRemoveUpdate("tensor(x{},y[1])", "{{x:a,y:0}:1}", "tensor(x{},y[1])",
+ "Unexpected type 'tensor(x{},y[1])' in remove update. Expected dimensions to be a subset of 'tensor(x{})'");
+
+ // Sparse dimension is not found in the original type.
+ illegalTensorRemoveUpdate("tensor(y{})", "{{y:a}:1}", "tensor(x{},z[2])",
+ "Unexpected type 'tensor(y{})' in remove update. Expected dimensions to be a subset of 'tensor(x{})'");
+ }
+
+ private void illegalTensorRemoveUpdate(String updateType, String updateTensor, String originalType, String expectedMessage) {
+ try {
+ var value = new TensorFieldValue(Tensor.from(updateType, updateTensor));
+ new TensorRemoveUpdate(value).verifyCompatibleType(TensorType.fromSpec(originalType));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals(expectedMessage, Exceptions.toMessageString(expected));
+ }
+ }
+
}
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 5fd62957f65..b88a0437fc2 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -1031,6 +1031,13 @@ TEST(DocumentUpdateTest, tensor_remove_update_can_be_roundtrip_serialized)
f.assertRoundtripSerialize(TensorRemoveUpdate(f.makeBaselineTensor()));
}
+TEST(DocumentUpdateTest, tensor_remove_update_with_not_fully_specified_address_can_be_roundtrip_serialized)
+{
+ TensorUpdateFixture f("sparse_xy_tensor");
+ TensorDataType type(ValueType::from_spec("tensor(y{})"));
+ f.assertRoundtripSerialize(TensorRemoveUpdate(
+ makeTensorFieldValue(TensorSpec("tensor(y{})").add({{"y", "a"}}, 1), type)));
+}
TEST(DocumentUpdateTest, tensor_remove_update_on_float_tensor_can_be_roundtrip_serialized)
{
@@ -1087,7 +1094,7 @@ TEST(DocumentUpdateTest, tensor_remove_update_throws_if_address_tensor_is_not_sp
auto addressTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense address tensor
ASSERT_THROW(
f.assertRoundtripSerialize(TensorRemoveUpdate(std::move(addressTensor))),
- document::WrongTensorTypeException);
+ vespalib::IllegalStateException);
}
TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_sparse)
diff --git a/document/src/vespa/document/base/testdocrepo.cpp b/document/src/vespa/document/base/testdocrepo.cpp
index 58d5a30ec35..24625c6f667 100644
--- a/document/src/vespa/document/base/testdocrepo.cpp
+++ b/document/src/vespa/document/base/testdocrepo.cpp
@@ -53,6 +53,7 @@ DocumenttypesConfig TestDocRepo::getDefaultConfig() {
.addField("rawarray", Array(DataType::T_RAW))
.addField("structarray", structarray_id)
.addTensorField("sparse_tensor", "tensor(x{})")
+ .addTensorField("sparse_xy_tensor", "tensor(x{},y{})")
.addTensorField("sparse_float_tensor", "tensor<float>(x{})")
.addTensorField("dense_tensor", "tensor(x[2])"));
builder.document(type2_id, "testdoctype2",
diff --git a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp
index 6ec9c52281f..eaa5a484ad1 100644
--- a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp
+++ b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp
@@ -355,10 +355,15 @@ void VespaDocumentDeserializer::read(WeightedSetFieldValue &value) {
}
}
-
void
VespaDocumentDeserializer::read(TensorFieldValue &value)
{
+ value.assignDeserialized(readTensor());
+}
+
+std::unique_ptr<vespalib::eval::Value>
+VespaDocumentDeserializer::readTensor()
+{
size_t length = _stream.getInt1_4Bytes();
if (length > _stream.size()) {
throw DeserializeException(vespalib::make_string("Stream failed size(%zu), needed(%zu) to deserialize tensor field value", _stream.size(), length),
@@ -372,8 +377,8 @@ VespaDocumentDeserializer::read(TensorFieldValue &value)
throw DeserializeException("Leftover bytes deserializing tensor field value.", VESPA_STRLOC);
}
}
- value.assignDeserialized(std::move(tensor));
_stream.adjustReadPos(length);
+ return tensor;
}
void VespaDocumentDeserializer::read(ReferenceFieldValue& value) {
diff --git a/document/src/vespa/document/serialization/vespadocumentdeserializer.h b/document/src/vespa/document/serialization/vespadocumentdeserializer.h
index e6b490e1075..6792914d9da 100644
--- a/document/src/vespa/document/serialization/vespadocumentdeserializer.h
+++ b/document/src/vespa/document/serialization/vespadocumentdeserializer.h
@@ -7,6 +7,7 @@
#include <memory>
namespace vespalib { class nbostream; }
+namespace vespalib::eval { class Value; }
namespace document {
class DocumentId;
@@ -78,6 +79,7 @@ public:
void readStructNoReset(StructFieldValue &value);
void read(WeightedSetFieldValue &value);
void read(TensorFieldValue &value);
+ std::unique_ptr<vespalib::eval::Value> readTensor();
void read(ReferenceFieldValue& value);
};
} // namespace document
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index 5d85b8956fa..688f9cf5399 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -20,6 +20,7 @@
using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
using vespalib::make_string;
+using vespalib::eval::Value;
using vespalib::eval::ValueType;
using vespalib::eval::EngineOrFactory;
using vespalib::tensor::TensorPartialUpdate;
@@ -157,38 +158,47 @@ TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &in
namespace {
void
-verifyAddressTensorIsSparse(const vespalib::eval::Value *addressTensor)
+verifyAddressTensorIsSparse(const Value *addressTensor)
{
if (addressTensor == nullptr) {
- return;
+ throw IllegalStateException("Address tensor is not set", VESPA_STRLOC);
}
auto engine = EngineOrFactory::get();
if (TensorPartialUpdate::check_suitably_sparse(*addressTensor, engine)) {
return;
}
- vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'",
- addressTensor->type().to_spec().c_str());
+ auto err = make_string("Expected address tensor to be sparse, but has type '%s'",
+ addressTensor->type().to_spec().c_str());
throw IllegalStateException(err, VESPA_STRLOC);
}
+void
+verify_tensor_type_dimensions_are_subset_of(const ValueType& lhs_type,
+ const ValueType& rhs_type)
+{
+ for (const auto& dim : lhs_type.dimensions()) {
+ if (rhs_type.dimension_index(dim.name) == ValueType::Dimension::npos) {
+ auto err = make_string("Unexpected type '%s' for address tensor. "
+ "Expected dimensions to be a subset of '%s'",
+ lhs_type.to_spec().c_str(), rhs_type.to_spec().c_str());
+ throw IllegalStateException(err, VESPA_STRLOC);
+ }
+ }
+}
}
void
TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream)
{
- _tensorType = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type));
- auto tensor = _tensorType->createFieldValue();
- if (tensor->inherits(TensorFieldValue::classId)) {
- _tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
- } else {
- vespalib::string err = make_string("Expected tensor field value, got a '%s' field value",
- tensor->getClass().name());
- throw IllegalStateException(err, VESPA_STRLOC);
- }
VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion());
- deserializer.read(*_tensor);
- verifyAddressTensorIsSparse(_tensor->getAsTensorPtr());
+ auto tensor = deserializer.readTensor();
+ verifyAddressTensorIsSparse(tensor.get());
+ auto compatible_type = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type));
+ verify_tensor_type_dimensions_are_subset_of(tensor->type(), compatible_type->getTensorType());
+ _tensorType = std::make_unique<const TensorDataType>(tensor->type());
+ _tensor = std::make_unique<TensorFieldValue>(*_tensorType);
+ _tensor->assignDeserialized(std::move(tensor));
}
TensorRemoveUpdate *