summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
Diffstat (limited to 'document')
-rw-r--r--document/src/tests/documentupdatetestcase.cpp15
-rw-r--r--document/src/vespa/document/base/testdocrepo.cpp1
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp2
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.cpp2
4 files changed, 18 insertions, 2 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index b8a0101a782..5543cb48ba4 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -988,6 +988,13 @@ TEST(DocumentUpdateTest, tensor_remove_update_can_be_roundtrip_serialized)
f.assertRoundtripSerialize(TensorRemoveUpdate(f.makeBaselineTensor()));
}
+
+TEST(DocumentUpdateTest, tensor_remove_update_on_float_tensor_can_be_roundtrip_serialized)
+{
+ TensorUpdateFixture f("sparse_float_tensor");
+ f.assertRoundtripSerialize(TensorRemoveUpdate(f.makeBaselineTensor()));
+}
+
TEST(DocumentUpdateTest, tensor_modify_update_can_be_roundtrip_serialized)
{
TensorUpdateFixture f;
@@ -996,6 +1003,14 @@ TEST(DocumentUpdateTest, tensor_modify_update_can_be_roundtrip_serialized)
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor()));
}
+TEST(DocumentUpdateTest, tensor_modify_update_on_float_tensor_can_be_roundtrip_serialized)
+{
+ TensorUpdateFixture f("sparse_float_tensor");
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor()));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor()));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor()));
+}
+
TEST(DocumentUpdateTest, tensor_modify_update_on_dense_tensor_can_be_roundtrip_serialized)
{
TensorUpdateFixture f("dense_tensor");
diff --git a/document/src/vespa/document/base/testdocrepo.cpp b/document/src/vespa/document/base/testdocrepo.cpp
index 68a58ea1a86..58d5a30ec35 100644
--- a/document/src/vespa/document/base/testdocrepo.cpp
+++ b/document/src/vespa/document/base/testdocrepo.cpp
@@ -53,6 +53,7 @@ DocumenttypesConfig TestDocRepo::getDefaultConfig() {
.addField("rawarray", Array(DataType::T_RAW))
.addField("structarray", structarray_id)
.addTensorField("sparse_tensor", "tensor(x{})")
+ .addTensorField("sparse_float_tensor", "tensor<float>(x{})")
.addTensorField("dense_tensor", "tensor(x[2])"));
builder.document(type2_id, "testdoctype2",
Struct("testdoctype2.header")
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index 37842b13cf4..0d821de8922 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -76,7 +76,7 @@ convertToCompatibleType(const TensorDataType &tensorType)
for (const auto &dim : tensorType.getTensorType().dimensions()) {
list.emplace_back(dim.name);
}
- return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list)));
+ return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list), tensorType.getTensorType().cell_type()));
}
}
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index 24aba4ece5a..f3bf7da7a0b 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -32,7 +32,7 @@ convertToCompatibleType(const TensorDataType &tensorType)
list.emplace_back(dim.name);
}
}
- return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list)));
+ return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list), tensorType.getTensorType().cell_type()));
}
}