diff options
author | Geir Storli <geirst@verizonmedia.com> | 2020-11-17 15:57:17 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-11-17 15:57:17 +0100 |
commit | 2571701816682b8d5989dc49bac7c5441feca213 (patch) | |
tree | b8ecb76cbd5bdbb96f68420e2077d6a0ce398d58 | |
parent | 8a4e20a2542e5d9407ca474e2a1e30902bc4158b (diff) | |
parent | cf02c8777d8bff26b2f1cc73e342c38945b7c94c (diff) |
Merge pull request #15368 from vespa-engine/geirst/extend-tensor-remove-update
Extend tensor remove update to handle not fully specified addresses
16 files changed, 326 insertions, 71 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json index c9191aa2fdb..b119f9991b3 100644 --- a/document/abi-spec.json +++ b/document/abi-spec.json @@ -3150,9 +3150,11 @@ "public" ], "methods": [ + "public void <init>()", "public void <init>(com.yahoo.tensor.TensorType)", "public void <init>(com.yahoo.tensor.Tensor)", "public java.util.Optional getTensor()", + "public java.util.Optional getTensorType()", "public com.yahoo.document.TensorDataType getDataType()", "public java.lang.String toString()", "public void printXml(com.yahoo.document.serialization.XmlStream)", @@ -4379,6 +4381,7 @@ ], "methods": [ "public void <init>(com.yahoo.document.datatypes.TensorFieldValue)", + "public void verifyCompatibleType(com.yahoo.tensor.TensorType)", "protected void checkCompatibility(com.yahoo.document.DataType)", "public void serialize(com.yahoo.document.serialization.DocumentUpdateWriter, com.yahoo.document.DataType)", "public com.yahoo.document.datatypes.FieldValue applyTo(com.yahoo.document.datatypes.FieldValue)", diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java index 2c6a556c652..8e7dbd3512a 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java @@ -20,17 +20,27 @@ public class TensorFieldValue extends FieldValue { private Optional<Tensor> tensor; - private final TensorDataType dataType; + private Optional<TensorDataType> dataType; + + /** + * Create an empty tensor field value where the tensor type is not yet known. + * + * The tensor (and tensor type) can later be assigned with assignTensor(). + */ + public TensorFieldValue() { + this.dataType = Optional.empty(); + this.tensor = Optional.empty(); + } - /** Create an empty tensor field value */ + /** Create an empty tensor field value for the given tensor type */ public TensorFieldValue(TensorType type) { - this.dataType = new TensorDataType(type); + this.dataType = Optional.of(new TensorDataType(type)); this.tensor = Optional.empty(); } /** Create a tensor field value containing the given tensor */ public TensorFieldValue(Tensor tensor) { - this.dataType = new TensorDataType(tensor.type()); + this.dataType = Optional.of(new TensorDataType(tensor.type())); this.tensor = Optional.of(tensor); } @@ -38,9 +48,13 @@ public class TensorFieldValue extends FieldValue { return tensor; } + public Optional<TensorType> getTensorType() { + return dataType.isPresent() ? Optional.of(dataType.get().getTensorType()) : Optional.empty(); + } + @Override public TensorDataType getDataType() { - return dataType; + return dataType.get(); } @Override @@ -76,10 +90,22 @@ public class TensorFieldValue extends FieldValue { } } + /** + * Assigns the given tensor to this field value. + * + * The tensor type is also set from the given tensor if it was not set before. + */ public void assignTensor(Optional<Tensor> tensor) { - if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType())) - throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() + - " to field of type " + dataType.getTensorType()); + if (tensor.isPresent()) { + if (getTensorType().isPresent() && + !tensor.get().type().isAssignableTo(getTensorType().get())) { + throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() + + " to field of type " + getTensorType().get()); + } + if (getTensorType().isEmpty()) { + this.dataType = Optional.of(new TensorDataType(tensor.get().type())); + } + } this.tensor = tensor; } @@ -99,7 +125,7 @@ public class TensorFieldValue extends FieldValue { if ( ! (o instanceof TensorFieldValue)) return false; TensorFieldValue other = (TensorFieldValue)o; - if ( ! dataType.getTensorType().equals(other.dataType.getTensorType())) return false; + if ( ! getTensorType().equals(other.getTensorType())) return false; if ( ! tensor.equals(other.tensor)) return false; return true; } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java index 91c275b6da0..cffc85777dc 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java @@ -1,6 +1,7 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document.json.readers; +import com.yahoo.collections.Pair; import com.yahoo.document.Field; import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; @@ -10,6 +11,8 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.util.HashMap; + import static com.yahoo.document.json.readers.JsonParserHelpers.expectArrayStart; import static com.yahoo.document.json.readers.JsonParserHelpers.expectCompositeEnd; import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd; @@ -28,8 +31,8 @@ public class TensorRemoveUpdateReader { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType originalType = tensorDataType.getTensorType(); - TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType); - Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType); + TensorType sparseType = TensorRemoveUpdate.extractSparseDimensions(originalType); + Tensor tensor = readRemoveUpdateTensor(buffer, sparseType, originalType); expectAddressesAreNonEmpty(field, tensor); return new TensorRemoveUpdate(new TensorFieldValue(tensor)); @@ -54,8 +57,8 @@ public class TensorRemoveUpdateReader { /** * Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0 */ - private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) { - Tensor.Builder builder = Tensor.Builder.of(type); + private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType sparseType, TensorType originalType) { + Tensor.Builder builder = null; expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { @@ -63,13 +66,55 @@ public class TensorRemoveUpdateReader { expectArrayStart(buffer.currentToken()); int nesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) { - builder.cell(readTensorAddress(buffer, type, originalType), 1.0); + if (builder == null) { + var typeAndAddress = readFirstTensorAddress(buffer, sparseType, originalType); + builder = Tensor.Builder.of(typeAndAddress.getFirst()); + builder.cell(typeAndAddress.getSecond(), 1.0); + } else { + builder.cell(readTensorAddress(buffer, builder.type(), originalType), 1.0); + } } expectCompositeEnd(buffer.currentToken()); } } expectObjectEnd(buffer.currentToken()); - return builder.build(); + return (builder != null) ? builder.build() : Tensor.Builder.of(sparseType).build(); + } + + /** + * Reads the first raw tensor address from the given buffer and resolves and returns the tensor type and tensor address based on this. + * The resulting tensor type contains a subset or all of the dimensions from the given sparseType. + */ + private static Pair<TensorType, TensorAddress> readFirstTensorAddress(TokenBuffer buffer, TensorType sparseType, TensorType originalType) { + var typeBuilder = new TensorType.Builder(sparseType.valueType()); + var rawAddress = new HashMap<String, String>(); + expectObjectStart(buffer.currentToken()); + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { + var elem = readRawElement(buffer, sparseType, originalType); + var dimension = sparseType.dimension(elem.getFirst()); + if (dimension.isPresent()) { + typeBuilder.dimension(dimension.get()); + rawAddress.put(elem.getFirst(), elem.getSecond()); + } else { + throw new IllegalArgumentException(originalType + " does not contain dimension '" + elem.getFirst() + "'"); + } + } + expectObjectEnd(buffer.currentToken()); + var type = typeBuilder.build(); + var builder = new TensorAddress.Builder(type); + rawAddress.forEach((dimension, label) -> builder.add(dimension, label)); + return new Pair<>(type, builder.build()); + } + + private static Pair<String, String> readRawElement(TokenBuffer buffer, TensorType type, TensorType originalType) { + String dimension = buffer.currentName(); + if (type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) { + throw new IllegalArgumentException("Indexed dimension address '" + dimension + + "' should not be specified in remove update"); + } + String label = buffer.currentText(); + return new Pair<>(dimension, label); } private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) { @@ -77,13 +122,8 @@ public class TensorRemoveUpdateReader { expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { - String dimension = buffer.currentName(); - if ( type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) { - throw new IllegalArgumentException("Indexed dimension address '" + dimension + - "' should not be specified in remove update"); - } - String label = buffer.currentText(); - builder.add(dimension, label); + var elem = readRawElement(buffer, type, originalType); + builder.add(elem.getFirst(), elem.getSecond()); } expectObjectEnd(buffer.currentToken()); return builder.build(); diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java index cac05fb7879..92b3b566b85 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java @@ -243,7 +243,7 @@ public class VespaDocumentDeserializer6 extends BufferSerializer implements Docu int encodedTensorLength = buf.getInt1_4Bytes(); if (encodedTensorLength > 0) { byte[] encodedTensor = getBytes(null, encodedTensorLength); - value.assign(TypedBinaryFormat.decode(Optional.of(value.getDataType().getTensorType()), + value.assign(TypedBinaryFormat.decode(value.getTensorType(), GrowableByteBuffer.wrap(encodedTensor))); } else { value.clear(); diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java index 58c50f047f9..e7f1525ff81 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java @@ -61,10 +61,12 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { } TensorDataType tensorDataType = (TensorDataType)type; TensorType tensorType = tensorDataType.getTensorType(); - TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType); - TensorFieldValue tensor = new TensorFieldValue(convertedType); + TensorFieldValue tensor = new TensorFieldValue(); tensor.deserialize(this); - return new TensorRemoveUpdate(tensor); + var result = new TensorRemoveUpdate(tensor); + result.verifyCompatibleType(tensorType); + return result; } + } diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index 981120af145..a300565391f 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -12,9 +12,10 @@ import com.yahoo.tensor.TensorType; import java.util.Objects; /** - * An update used to remove cells from a sparse tensor (has only mapped dimensions). + * An update used to remove cells from a sparse tensor or dense sub-spaces from a mixed tensor. * - * The cells to remove are contained in a sparse tensor where cell values are set to 1.0 + * The specification of which cells to remove contains addresses using a subset or all of the sparse dimensions of the tensor type. + * This is represented as a sparse tensor where cell values are set to 1.0. */ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { @@ -23,17 +24,20 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { public TensorRemoveUpdate(TensorFieldValue value) { super(ValueUpdateClassID.TENSORREMOVE); this.tensor = value; - verifyCompatibleType(); - } - - private void verifyCompatibleType() { - if ( ! tensor.getTensor().isPresent()) { + if (!tensor.getTensor().isPresent()) { throw new IllegalArgumentException("Tensor must be present in remove update"); } - TensorType tensorType = tensor.getTensor().get().type(); - TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType()); - if ( ! tensorType.equals(expectedType)) { - throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'"); + verifyCompatibleType(tensor.getTensorType().get()); + } + + public void verifyCompatibleType(TensorType originalType) { + TensorType sparseType = extractSparseDimensions(originalType); + TensorType thisType = tensor.getTensorType().get(); + for (var dim : thisType.dimensions()) { + if (sparseType.dimension(dim.name()).isEmpty()) { + throw new IllegalArgumentException("Unexpected type '" + thisType + "' in remove update. " + + "Expected dimensions to be a subset of '" + sparseType + "'"); + } } } @@ -63,6 +67,7 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); Tensor update = tensor.getTensor().get(); + // TODO: handle the case where this tensor only contains a subset of the sparse dimensions of the input tensor. Tensor result = old.remove(update.cells().keySet()); return new TensorFieldValue(result); } @@ -102,5 +107,4 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { return builder.build(); } - } diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java index 4e8fa427e7d..1772a410a36 100644 --- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java +++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java @@ -436,6 +436,25 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_remove_update_with_not_fully_specified_address() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'sparse_tensor': {", + " 'remove': {", + " 'addresses': [", + " {'y':'0'},", + " {'y':'2'}", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void reference_field_id_can_be_update_assigned_non_empty_id() { roundtripSerializeJsonAndMatch(inputJson( "{", diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 7fc43656d55..da9ab4ea7bf 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -168,6 +168,8 @@ public class JsonReaderTestCase { new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build()))); x.addField(new Field("mixed_tensor", new TensorDataType(new TensorType.Builder().mapped("x").indexed("y", 3).build()))); + x.addField(new Field("mixed_tensor_adv", + new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").mapped("z").indexed("a", 3).build()))); types.registerDocumentType(x); } { @@ -1685,6 +1687,24 @@ public class JsonReaderTestCase { } @Test + public void tensor_remove_update_on_sparse_tensor_with_not_fully_specified_address() { + assertTensorRemoveUpdate("{{y:b}:1.0,{y:d}:1.0}", "sparse_tensor", + inputJson("{", + " 'addresses': [", + " { 'y': 'b' },", + " { 'y': 'd' } ]}")); + } + + @Test + public void tensor_remove_update_on_mixed_tensor_with_not_fully_specified_address() { + assertTensorRemoveUpdate("{{x:1,z:a}:1.0,{x:2,z:b}:1.0}", "mixed_tensor_adv", + inputJson("{", + " 'addresses': [", + " { 'x': '1', 'z': 'a' },", + " { 'x': '2', 'z': 'b' } ]}")); + } + + @Test public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() { illegalTensorRemoveUpdate("Error in 'mixed_tensor': Indexed dimension address 'y' should not be specified in remove update", "mixed_tensor", @@ -1703,12 +1723,19 @@ public class JsonReaderTestCase { } @Test - public void tensor_remove_update_on_not_fully_specified_cell_throws() { - illegalTensorRemoveUpdate("Error in 'sparse_tensor': Missing a label for dimension y for tensor(x{},y{})", - "sparse_tensor", - "{", - " 'addresses': [", - " { 'x': 'a' } ]}"); + public void tensor_remove_update_with_stray_dimension_throws() { + illegalTensorRemoveUpdate("Error in 'sparse_tensor': tensor(x{},y{}) does not contain dimension 'foo'", + "sparse_tensor", + "{", + " 'addresses': [", + " { 'x': 'a', 'foo': 'b' } ]}"); + + illegalTensorRemoveUpdate("Error in 'sparse_tensor': tensor(x{}) does not contain dimension 'foo'", + "sparse_tensor", + "{", + " 'addresses': [", + " { 'x': 'c' },", + " { 'x': 'a', 'foo': 'b' } ]}"); } @Test diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java index 3a005e858c8..86f07db1b2d 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java @@ -3,9 +3,12 @@ package com.yahoo.document.update; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.yolean.Exceptions; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; public class TensorRemoveUpdateTest { @@ -22,4 +25,26 @@ public class TensorRemoveUpdateTest { assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); } + @Test + public void verify_compatible_type_throws_on_mismatch() { + // Contains an indexed dimension, which is not allowed. + illegalTensorRemoveUpdate("tensor(x{},y[1])", "{{x:a,y:0}:1}", "tensor(x{},y[1])", + "Unexpected type 'tensor(x{},y[1])' in remove update. Expected dimensions to be a subset of 'tensor(x{})'"); + + // Sparse dimension is not found in the original type. + illegalTensorRemoveUpdate("tensor(y{})", "{{y:a}:1}", "tensor(x{},z[2])", + "Unexpected type 'tensor(y{})' in remove update. Expected dimensions to be a subset of 'tensor(x{})'"); + } + + private void illegalTensorRemoveUpdate(String updateType, String updateTensor, String originalType, String expectedMessage) { + try { + var value = new TensorFieldValue(Tensor.from(updateType, updateTensor)); + new TensorRemoveUpdate(value).verifyCompatibleType(TensorType.fromSpec(originalType)); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals(expectedMessage, Exceptions.toMessageString(expected)); + } + } + } diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 5fd62957f65..b88a0437fc2 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -1031,6 +1031,13 @@ TEST(DocumentUpdateTest, tensor_remove_update_can_be_roundtrip_serialized) f.assertRoundtripSerialize(TensorRemoveUpdate(f.makeBaselineTensor())); } +TEST(DocumentUpdateTest, tensor_remove_update_with_not_fully_specified_address_can_be_roundtrip_serialized) +{ + TensorUpdateFixture f("sparse_xy_tensor"); + TensorDataType type(ValueType::from_spec("tensor(y{})")); + f.assertRoundtripSerialize(TensorRemoveUpdate( + makeTensorFieldValue(TensorSpec("tensor(y{})").add({{"y", "a"}}, 1), type))); +} TEST(DocumentUpdateTest, tensor_remove_update_on_float_tensor_can_be_roundtrip_serialized) { @@ -1087,7 +1094,7 @@ TEST(DocumentUpdateTest, tensor_remove_update_throws_if_address_tensor_is_not_sp auto addressTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense address tensor ASSERT_THROW( f.assertRoundtripSerialize(TensorRemoveUpdate(std::move(addressTensor))), - document::WrongTensorTypeException); + vespalib::IllegalStateException); } TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_sparse) diff --git a/document/src/vespa/document/base/testdocrepo.cpp b/document/src/vespa/document/base/testdocrepo.cpp index 58d5a30ec35..24625c6f667 100644 --- a/document/src/vespa/document/base/testdocrepo.cpp +++ b/document/src/vespa/document/base/testdocrepo.cpp @@ -53,6 +53,7 @@ DocumenttypesConfig TestDocRepo::getDefaultConfig() { .addField("rawarray", Array(DataType::T_RAW)) .addField("structarray", structarray_id) .addTensorField("sparse_tensor", "tensor(x{})") + .addTensorField("sparse_xy_tensor", "tensor(x{},y{})") .addTensorField("sparse_float_tensor", "tensor<float>(x{})") .addTensorField("dense_tensor", "tensor(x[2])")); builder.document(type2_id, "testdoctype2", diff --git a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp index 6ec9c52281f..eaa5a484ad1 100644 --- a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp +++ b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp @@ -355,10 +355,15 @@ void VespaDocumentDeserializer::read(WeightedSetFieldValue &value) { } } - void VespaDocumentDeserializer::read(TensorFieldValue &value) { + value.assignDeserialized(readTensor()); +} + +std::unique_ptr<vespalib::eval::Value> +VespaDocumentDeserializer::readTensor() +{ size_t length = _stream.getInt1_4Bytes(); if (length > _stream.size()) { throw DeserializeException(vespalib::make_string("Stream failed size(%zu), needed(%zu) to deserialize tensor field value", _stream.size(), length), @@ -372,8 +377,8 @@ VespaDocumentDeserializer::read(TensorFieldValue &value) throw DeserializeException("Leftover bytes deserializing tensor field value.", VESPA_STRLOC); } } - value.assignDeserialized(std::move(tensor)); _stream.adjustReadPos(length); + return tensor; } void VespaDocumentDeserializer::read(ReferenceFieldValue& value) { diff --git a/document/src/vespa/document/serialization/vespadocumentdeserializer.h b/document/src/vespa/document/serialization/vespadocumentdeserializer.h index e6b490e1075..6792914d9da 100644 --- a/document/src/vespa/document/serialization/vespadocumentdeserializer.h +++ b/document/src/vespa/document/serialization/vespadocumentdeserializer.h @@ -7,6 +7,7 @@ #include <memory> namespace vespalib { class nbostream; } +namespace vespalib::eval { class Value; } namespace document { class DocumentId; @@ -78,6 +79,7 @@ public: void readStructNoReset(StructFieldValue &value); void read(WeightedSetFieldValue &value); void read(TensorFieldValue &value); + std::unique_ptr<vespalib::eval::Value> readTensor(); void read(ReferenceFieldValue& value); }; } // namespace document diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index 5d85b8956fa..688f9cf5399 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -20,6 +20,7 @@ using vespalib::IllegalArgumentException; using vespalib::IllegalStateException; using vespalib::make_string; +using vespalib::eval::Value; using vespalib::eval::ValueType; using vespalib::eval::EngineOrFactory; using vespalib::tensor::TensorPartialUpdate; @@ -157,38 +158,47 @@ TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &in namespace { void -verifyAddressTensorIsSparse(const vespalib::eval::Value *addressTensor) +verifyAddressTensorIsSparse(const Value *addressTensor) { if (addressTensor == nullptr) { - return; + throw IllegalStateException("Address tensor is not set", VESPA_STRLOC); } auto engine = EngineOrFactory::get(); if (TensorPartialUpdate::check_suitably_sparse(*addressTensor, engine)) { return; } - vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'", - addressTensor->type().to_spec().c_str()); + auto err = make_string("Expected address tensor to be sparse, but has type '%s'", + addressTensor->type().to_spec().c_str()); throw IllegalStateException(err, VESPA_STRLOC); } +void +verify_tensor_type_dimensions_are_subset_of(const ValueType& lhs_type, + const ValueType& rhs_type) +{ + for (const auto& dim : lhs_type.dimensions()) { + if (rhs_type.dimension_index(dim.name) == ValueType::Dimension::npos) { + auto err = make_string("Unexpected type '%s' for address tensor. " + "Expected dimensions to be a subset of '%s'", + lhs_type.to_spec().c_str(), rhs_type.to_spec().c_str()); + throw IllegalStateException(err, VESPA_STRLOC); + } + } +} } void TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream) { - _tensorType = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type)); - auto tensor = _tensorType->createFieldValue(); - if (tensor->inherits(TensorFieldValue::classId)) { - _tensor.reset(static_cast<TensorFieldValue *>(tensor.release())); - } else { - vespalib::string err = make_string("Expected tensor field value, got a '%s' field value", - tensor->getClass().name()); - throw IllegalStateException(err, VESPA_STRLOC); - } VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion()); - deserializer.read(*_tensor); - verifyAddressTensorIsSparse(_tensor->getAsTensorPtr()); + auto tensor = deserializer.readTensor(); + verifyAddressTensorIsSparse(tensor.get()); + auto compatible_type = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type)); + verify_tensor_type_dimensions_are_subset_of(tensor->type(), compatible_type->getTensorType()); + _tensorType = std::make_unique<const TensorDataType>(tensor->type()); + _tensor = std::make_unique<TensorFieldValue>(*_tensorType); + _tensor->assignDeserialized(std::move(tensor)); } TensorRemoveUpdate * diff --git a/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp b/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp index 220eee0ba8f..e182fffa890 100644 --- a/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp +++ b/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp @@ -116,4 +116,28 @@ TEST(PartialRemoveTest, partial_remove_returns_nullptr_on_invalid_inputs) { } } +void +expect_partial_remove(const TensorSpec& input, const TensorSpec& remove, const TensorSpec& exp) +{ + auto act = perform_partial_remove(input, remove); + EXPECT_EQ(exp, act); +} + +TEST(PartialRemoveTest, remove_where_address_is_not_fully_specified) { + auto input = TensorSpec("tensor(x{},y{})"). + add({{"x", "a"},{"y", "c"}}, 3.0). + add({{"x", "a"},{"y", "d"}}, 5.0). + add({{"x", "b"},{"y", "c"}}, 7.0); + + expect_partial_remove(input,TensorSpec("tensor(x{})").add({{"x", "a"}}, 1.0), + TensorSpec("tensor(x{},y{})").add({{"x", "b"},{"y", "c"}}, 7.0)); + + expect_partial_remove(input, TensorSpec("tensor(y{})").add({{"y", "c"}}, 1.0), + TensorSpec("tensor(x{},y{})").add({{"x", "a"},{"y", "d"}}, 5.0)); + + expect_partial_remove(input, TensorSpec("tensor(y{})").add({{"y", "d"}}, 1.0), + TensorSpec("tensor(x{},y{})").add({{"x", "a"},{"y", "c"}}, 3.0) + .add({{"x", "b"},{"y", "c"}}, 7.0)); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/vespa/eval/tensor/partial_update.cpp b/eval/src/vespa/eval/tensor/partial_update.cpp index 014ffeb2666..fa15b2a38ae 100644 --- a/eval/src/vespa/eval/tensor/partial_update.cpp +++ b/eval/src/vespa/eval/tensor/partial_update.cpp @@ -298,31 +298,91 @@ struct PerformRemove { const ValueBuilderFactory &factory); }; +/** + * Calculates the indexes of where the mapped modifier dimensions are found in the mapped input dimensions. + * + * The modifier dimensions should be a subset or all of the input dimensions. + * An empty vector is returned on type mismatch. + */ +std::vector<size_t> +calc_mapped_dimension_indexes(const ValueType& input_type, + const ValueType& modifier_type) +{ + auto input_dims = input_type.mapped_dimensions(); + auto mod_dims = modifier_type.mapped_dimensions(); + if (mod_dims.size() > input_dims.size()) { + return {}; + } + std::vector<size_t> result(mod_dims.size()); + size_t j = 0; + for (size_t i = 0; i < mod_dims.size(); ++i) { + while ((j < input_dims.size()) && (input_dims[j] != mod_dims[i])) { + ++j; + } + if (j >= input_dims.size()) { + return {}; + } + result[i] = j; + } + return result; +} + +struct ModifierCoords { + + std::vector<const vespalib::stringref *> lookup_refs; + std::vector<size_t> lookup_view_dims; + + ModifierCoords(const SparseCoords& input_coords, + const std::vector<size_t>& input_dim_indexes, + const ValueType& modifier_type) + : lookup_refs(modifier_type.dimensions().size()), + lookup_view_dims(modifier_type.dimensions().size()) + { + assert(modifier_type.dimensions().size() == input_dim_indexes.size()); + for (size_t i = 0; i < input_dim_indexes.size(); ++i) { + // Setup the modifier dimensions to point to the matching input dimensions. + lookup_refs[i] = &input_coords.addr[input_dim_indexes[i]]; + lookup_view_dims[i] = i; + } + } + ~ModifierCoords() {} +}; + template <typename ICT> Value::UP PerformRemove::invoke(const Value &input, const Value &modifier, const ValueBuilderFactory &factory) { const ValueType &input_type = input.type(); const ValueType &modifier_type = modifier.type(); - if (input_type.mapped_dimensions() != modifier_type.dimensions()) { - LOG(error, "when removing cells from a tensor, mapped dimensions must be equal. " - "Got input type %s versus modifier type %s", - input_type.to_spec().c_str(), modifier_type.to_spec().c_str()); - return {}; - } const size_t num_mapped_in_input = input_type.count_mapped_dimensions(); if (num_mapped_in_input == 0) { - LOG(error, "cannot remove cells from a dense tensor of type %s", + LOG(error, "Cannot remove cells from a dense input tensor of type %s", input_type.to_spec().c_str()); return {}; } + if (modifier_type.count_indexed_dimensions() != 0) { + LOG(error, "Cannot remove cells using a modifier tensor of type %s", + modifier_type.to_spec().c_str()); + return {}; + } + auto input_dim_indexes = calc_mapped_dimension_indexes(input_type, modifier_type); + if (input_dim_indexes.empty()) { + LOG(error, "Tensor type mismatch when removing cells from a tensor. " + "Got input type %s versus modifier type %s", + input_type.to_spec().c_str(), modifier_type.to_spec().c_str()); + return {}; + } SparseCoords addrs(num_mapped_in_input); - auto modifier_view = modifier.index().create_view(addrs.lookup_view_dims); + ModifierCoords mod_coords(addrs, input_dim_indexes, modifier_type); + auto modifier_view = modifier.index().create_view(mod_coords.lookup_view_dims); const size_t expected_subspaces = input.index().size(); const size_t dsss = input_type.dense_subspace_size(); auto builder = factory.create_value_builder<ICT>(input_type, num_mapped_in_input, dsss, expected_subspaces); auto filter_by_modifier = [&] (const auto & lookup_refs, size_t) { - modifier_view->lookup(lookup_refs); + // The modifier dimensions are setup to point to the input dimensions address storage in ModifierCoords, + // so we don't need to use the lookup_refs argument. + (void) lookup_refs; + modifier_view->lookup(mod_coords.lookup_refs); size_t modifier_subspace_index; return !(modifier_view->next_result({}, modifier_subspace_index)); }; |