aboutsummaryrefslogtreecommitdiffstats
path: root/document/src
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-10-13 14:11:28 +0000
committerArne Juul <arnej@verizonmedia.com>2020-10-15 08:18:59 +0000
commit08393e9e14635f1c6a6c84650c25023a0db7ed0b (patch)
tree48aae1605140fc6ff7d571084f345d33a3189c62 /document/src
parent61eaea251e8cacd320ac10754ffd1513d8638043 (diff)
handle both engine- and factory-based tensors
* use EngineOrFactory::get() instead of DefaultTensorEngine::ref() * avoid direct use of DenseTensorView etc where possible * use eval::Value instead of tensor::Tensor where possible
Diffstat (limited to 'document/src')
-rw-r--r--document/src/tests/documentupdatetestcase.cpp18
-rw-r--r--document/src/tests/serialization/vespadocumentserializer_test.cpp40
-rw-r--r--document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp23
-rw-r--r--document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp22
-rw-r--r--document/src/vespa/document/fieldvalue/tensorfieldvalue.h9
-rw-r--r--document/src/vespa/document/serialization/vespadocumentdeserializer.cpp9
-rw-r--r--document/src/vespa/document/serialization/vespadocumentserializer.cpp5
-rw-r--r--document/src/vespa/document/update/tensor_add_update.cpp34
-rw-r--r--document/src/vespa/document/update/tensor_add_update.h4
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp57
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.h4
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.cpp56
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.h4
13 files changed, 182 insertions, 103 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 9d2567e93ed..ca519a2f7d0 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -21,8 +21,9 @@
#include <vespa/document/update/tensor_remove_update.h>
#include <vespa/document/update/valueupdate.h>
#include <vespa/document/util/bytebuffer.h>
-#include <vespa/eval/tensor/default_tensor_engine.h>
-#include <vespa/eval/tensor/tensor.h>
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/test/value_compare.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/exception.h>
#include <vespa/vespalib/util/exceptions.h>
@@ -33,10 +34,10 @@
#include <unistd.h>
using namespace document::config_builder;
+
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
-using vespalib::tensor::DefaultTensorEngine;
-using vespalib::tensor::Tensor;
+using vespalib::eval::EngineOrFactory;
using vespalib::nbostream;
namespace document {
@@ -771,11 +772,12 @@ TEST(DocumentUpdateTest, testMapValueUpdate)
EXPECT_EQ(fv4->find(StringFieldValue("apple")), fv4->end());
}
-std::unique_ptr<Tensor>
+std::unique_ptr<vespalib::eval::Value>
makeTensor(const TensorSpec &spec)
{
- auto result = DefaultTensorEngine::ref().from_spec(spec);
- return std::unique_ptr<Tensor>(dynamic_cast<Tensor*>(result.release()));
+ auto result = EngineOrFactory::get().from_spec(spec);
+ EXPECT_TRUE(result->is_tensor());
+ return result;
}
std::unique_ptr<TensorFieldValue>
@@ -787,7 +789,7 @@ makeTensorFieldValue(const TensorSpec &spec, const TensorDataType &dataType)
return result;
}
-const Tensor &asTensor(const FieldValue &fieldValue) {
+const vespalib::eval::Value &asTensor(const FieldValue &fieldValue) {
auto &tensorFieldValue = dynamic_cast<const TensorFieldValue &>(fieldValue);
auto tensor = tensorFieldValue.getAsTensorPtr();
assert(tensor);
diff --git a/document/src/tests/serialization/vespadocumentserializer_test.cpp b/document/src/tests/serialization/vespadocumentserializer_test.cpp
index 02f170cd5f1..13d5e7d8405 100644
--- a/document/src/tests/serialization/vespadocumentserializer_test.cpp
+++ b/document/src/tests/serialization/vespadocumentserializer_test.cpp
@@ -36,8 +36,9 @@
#include <vespa/document/serialization/vespadocumentdeserializer.h>
#include <vespa/document/serialization/vespadocumentserializer.h>
#include <vespa/document/serialization/annotationserializer.h>
-#include <vespa/eval/tensor/tensor.h>
-#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/test/value_compare.h>
#include <vespa/vespalib/io/fileutil.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/testkit/testapp.h>
@@ -50,8 +51,7 @@ using vespalib::nbostream;
using vespalib::nbostream_longlivedbuf;
using vespalib::slime::Cursor;
using vespalib::eval::TensorSpec;
-using vespalib::tensor::Tensor;
-using vespalib::tensor::DefaultTensorEngine;
+using vespalib::eval::EngineOrFactory;
using vespalib::compression::CompressionConfig;
using namespace document;
using std::string;
@@ -771,12 +771,10 @@ TEST("Require that predicate deserialization matches Java") {
namespace
{
-Tensor::UP createTensor(const TensorSpec &spec) {
- auto value = DefaultTensorEngine::ref().from_spec(spec);
- Tensor *tensor = dynamic_cast<Tensor*>(value.get());
- ASSERT_TRUE(tensor != nullptr);
- value.release();
- return Tensor::UP(tensor);
+vespalib::eval::Value::UP createTensor(const TensorSpec &spec) {
+ auto value = EngineOrFactory::get().from_spec(spec);
+ ASSERT_TRUE(value->is_tensor());
+ return value;
}
}
@@ -836,13 +834,13 @@ void deserializeAndCheck(const string &file_name, TensorFieldValue &value) {
deserializeAndCheck(file_name, value, tensor_repo, tensor_field_name);
}
-void checkDeserialization(const string &name, std::unique_ptr<Tensor> tensor) {
+void checkDeserialization(const string &name, std::unique_ptr<vespalib::eval::Value> tensor) {
const string data_dir = TEST_PATH("../../test/resources/tensor/");
TensorDataType valueType(tensor ? tensor->type() : vespalib::eval::ValueType::error_type());
TensorFieldValue value(valueType);
if (tensor) {
- value = tensor->clone();
+ value = EngineOrFactory::get().copy(*tensor);
}
serializeToFile(value, data_dir + name + "__cpp");
deserializeAndCheck(data_dir + name + "__cpp", value);
@@ -851,7 +849,7 @@ void checkDeserialization(const string &name, std::unique_ptr<Tensor> tensor) {
TEST("Require that tensor deserialization matches Java") {
- checkDeserialization("non_existing_tensor", std::unique_ptr<Tensor>());
+ checkDeserialization("non_existing_tensor", std::unique_ptr<vespalib::eval::Value>());
checkDeserialization("empty_tensor", createTensor(TensorSpec("tensor(dimX{},dimY{})")));
checkDeserialization("multi_cell_tensor",
createTensor(TensorSpec("tensor(dimX{},dimY{})")
@@ -863,17 +861,17 @@ TEST("Require that tensor deserialization matches Java") {
struct TensorDocFixture {
const DocumentTypeRepo &_docTypeRepo;
const DocumentType *_docType;
- std::unique_ptr<Tensor> _tensor;
+ std::unique_ptr<vespalib::eval::Value> _tensor;
Document _doc;
vespalib::nbostream _blob;
TensorDocFixture(const DocumentTypeRepo &docTypeRepo,
- std::unique_ptr<Tensor> tensor);
+ std::unique_ptr<vespalib::eval::Value> tensor);
~TensorDocFixture();
};
TensorDocFixture::TensorDocFixture(const DocumentTypeRepo &docTypeRepo,
- std::unique_ptr<Tensor> tensor)
+ std::unique_ptr<vespalib::eval::Value> tensor)
: _docTypeRepo(docTypeRepo),
_docType(_docTypeRepo.getDocumentType(tensor_doc_type_id)),
_tensor(std::move(tensor)),
@@ -881,7 +879,7 @@ TensorDocFixture::TensorDocFixture(const DocumentTypeRepo &docTypeRepo,
_blob()
{
auto fv = _doc.getField(tensor_field_name).createValue();
- dynamic_cast<TensorFieldValue &>(*fv) = _tensor->clone();
+ dynamic_cast<TensorFieldValue &>(*fv) = EngineOrFactory::get().copy(*_tensor);
_doc.setValue(tensor_field_name, *fv);
_doc.serialize(_blob);
}
@@ -897,7 +895,7 @@ struct DeserializedTensorDoc
~DeserializedTensorDoc();
void setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob);
- const Tensor *getTensor() const;
+ const vespalib::eval::Value *getTensor() const;
};
DeserializedTensorDoc::DeserializedTensorDoc()
@@ -916,7 +914,7 @@ DeserializedTensorDoc::setup(const DocumentTypeRepo &docTypeRepo, const vespalib
_fieldValue = _doc->getValue(tensor_field_name);
}
-const Tensor *
+const vespalib::eval::Value *
DeserializedTensorDoc::getTensor() const
{
return dynamic_cast<const TensorFieldValue &>(*_fieldValue).getAsTensorPtr();
@@ -936,14 +934,14 @@ TEST("Require that wrong tensor type hides tensor")
DeserializedTensorDoc doc;
doc.setup(tensor_doc_repo, f._blob);
EXPECT_TRUE(doc.getTensor() != nullptr);
- EXPECT_TRUE(doc.getTensor()->equals(*f._tensor));
+ EXPECT_TRUE((*doc.getTensor()) == (*f._tensor));
doc.setup(tensor_doc_repo, f1._blob);
EXPECT_TRUE(doc.getTensor() == nullptr);
doc.setup(tensor_doc_repo1, f._blob);
EXPECT_TRUE(doc.getTensor() == nullptr);
doc.setup(tensor_doc_repo1, f1._blob);
EXPECT_TRUE(doc.getTensor() != nullptr);
- EXPECT_TRUE(doc.getTensor()->equals(*f1._tensor));
+ EXPECT_TRUE((*doc.getTensor()) == (*f1._tensor));
}
struct RefFixture {
diff --git a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp
index 9d2da9c983a..18afdb15bb8 100644
--- a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp
+++ b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp
@@ -7,9 +7,8 @@ LOG_SETUP("fieldvalue_test");
#include <vespa/document/base/exceptions.h>
#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
-#include <vespa/eval/tensor/tensor.h>
-#include <vespa/eval/tensor/types.h>
-#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/eval/value.h>
#include <vespa/eval/tensor/test/test_utils.h>
#include <vespa/vespalib/testkit/testapp.h>
@@ -18,7 +17,7 @@ using namespace document;
using namespace vespalib::tensor;
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
-using vespalib::tensor::DefaultTensorEngine;
+using vespalib::eval::EngineOrFactory;
using vespalib::tensor::test::makeTensor;
namespace
@@ -27,19 +26,17 @@ namespace
TensorDataType xSparseTensorDataType(ValueType::from_spec("tensor(x{})"));
TensorDataType xySparseTensorDataType(ValueType::from_spec("tensor(x{},y{})"));
-Tensor::UP createTensor(const TensorSpec &spec) {
- auto value = DefaultTensorEngine::ref().from_spec(spec);
- Tensor *tensor = dynamic_cast<Tensor*>(value.get());
- ASSERT_TRUE(tensor != nullptr);
- value.release();
- return Tensor::UP(tensor);
+vespalib::eval::Value::UP createTensor(const TensorSpec &spec) {
+ auto value = EngineOrFactory::get().from_spec(spec);
+ ASSERT_TRUE(value->is_tensor());
+ return value;
}
-std::unique_ptr<Tensor>
+std::unique_ptr<vespalib::eval::Value>
makeSimpleTensor()
{
- return makeTensor<Tensor>(TensorSpec("tensor(x{},y{})").
- add({{"x", "4"}, {"y", "5"}}, 7));
+ return makeTensor<vespalib::eval::Value>(TensorSpec("tensor(x{},y{})").
+ add({{"x", "4"}, {"y", "5"}}, 7));
}
FieldValue::UP clone(FieldValue &fv) {
diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
index c3b593732b9..2a66ea61966 100644
--- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
+++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
@@ -6,16 +6,14 @@
#include <vespa/vespalib/util/xmlstream.h>
#include <vespa/eval/eval/engine_or_factory.h>
#include <vespa/eval/eval/tensor_spec.h>
-#include <vespa/eval/tensor/tensor.h>
-#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/engine_or_factory.h>
#include <ostream>
#include <cassert>
-using vespalib::tensor::Tensor;
using vespalib::eval::EngineOrFactory;
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
-using Engine = vespalib::tensor::DefaultTensorEngine;
using namespace vespalib::xml;
namespace document {
@@ -53,7 +51,7 @@ TensorFieldValue::TensorFieldValue(const TensorFieldValue &rhs)
_altered(true)
{
if (rhs._tensor) {
- _tensor = rhs._tensor->clone();
+ _tensor = EngineOrFactory::get().copy(*rhs._tensor);
}
}
@@ -80,7 +78,7 @@ TensorFieldValue::operator=(const TensorFieldValue &rhs)
if (&_dataType == &rhs._dataType || !rhs._tensor ||
_dataType.isAssignableType(rhs._tensor->type())) {
if (rhs._tensor) {
- _tensor = rhs._tensor->clone();
+ _tensor = EngineOrFactory::get().copy(*rhs._tensor);
} else {
_tensor.reset();
}
@@ -94,7 +92,7 @@ TensorFieldValue::operator=(const TensorFieldValue &rhs)
TensorFieldValue &
-TensorFieldValue::operator=(std::unique_ptr<Tensor> rhs)
+TensorFieldValue::operator=(std::unique_ptr<vespalib::eval::Value> rhs)
{
if (!rhs || _dataType.isAssignableType(rhs->type())) {
_tensor = std::move(rhs);
@@ -111,11 +109,7 @@ TensorFieldValue::make_empty_if_not_existing()
{
if (!_tensor) {
TensorSpec empty_spec(_dataType.getTensorType().to_spec());
- auto empty_value = Engine::ref().from_spec(empty_spec);
- auto tensor_ptr = dynamic_cast<Tensor*>(empty_value.get());
- assert(tensor_ptr != nullptr);
- _tensor.reset(tensor_ptr);
- empty_value.release();
+ _tensor = EngineOrFactory::get().from_spec(empty_spec);
}
}
@@ -163,7 +157,7 @@ TensorFieldValue::print(std::ostream& out, bool verbose,
(void) indent;
out << "{TensorFieldValue: ";
if (_tensor) {
- out << Engine::ref().to_spec(*_tensor).to_string();
+ out << EngineOrFactory::get().to_spec(*_tensor).to_string();
} else {
out << "null";
}
@@ -192,7 +186,7 @@ TensorFieldValue::assign(const FieldValue &value)
void
-TensorFieldValue::assignDeserialized(std::unique_ptr<Tensor> rhs)
+TensorFieldValue::assignDeserialized(std::unique_ptr<vespalib::eval::Value> rhs)
{
if (!rhs || _dataType.isAssignableType(rhs->type())) {
_tensor = std::move(rhs);
diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
index 30cc10558b5..82a10e8aaa6 100644
--- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
+++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
@@ -5,6 +5,7 @@
#include "fieldvalue.h"
namespace vespalib { namespace tensor { class Tensor; } }
+namespace vespalib::eval { class Value; }
namespace document {
@@ -16,7 +17,7 @@ class TensorDataType;
class TensorFieldValue : public FieldValue {
private:
const TensorDataType &_dataType;
- std::unique_ptr<vespalib::tensor::Tensor> _tensor;
+ std::unique_ptr<vespalib::eval::Value> _tensor;
bool _altered;
public:
TensorFieldValue();
@@ -26,7 +27,7 @@ public:
~TensorFieldValue();
TensorFieldValue &operator=(const TensorFieldValue &rhs);
- TensorFieldValue &operator=(std::unique_ptr<vespalib::tensor::Tensor> rhs);
+ TensorFieldValue &operator=(std::unique_ptr<vespalib::eval::Value> rhs);
void make_empty_if_not_existing();
@@ -39,10 +40,10 @@ public:
const std::string& indent) const override;
virtual void printXml(XmlOutputStream& out) const override;
virtual FieldValue &assign(const FieldValue &value) override;
- const vespalib::tensor::Tensor *getAsTensorPtr() const {
+ const vespalib::eval::Value *getAsTensorPtr() const {
return _tensor.get();
}
- void assignDeserialized(std::unique_ptr<vespalib::tensor::Tensor> rhs);
+ void assignDeserialized(std::unique_ptr<vespalib::eval::Value> rhs);
virtual int compare(const FieldValue& other) const override;
DECLARE_IDENTIFIABLE(TensorFieldValue);
diff --git a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp
index 94644438f5c..6ec9c52281f 100644
--- a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp
+++ b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp
@@ -22,8 +22,8 @@
#include <vespa/vespalib/data/slime/slime.h>
#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/util/backtrace.h>
-#include <vespa/eval/tensor/tensor.h>
-#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/eval/value.h>
#include <vespa/document/util/serializableexceptions.h>
#include <vespa/document/base/exceptions.h>
#include <vespa/vespalib/objects/nbostream.h>
@@ -41,6 +41,7 @@ using vespalib::nbostream;
using vespalib::Memory;
using vespalib::stringref;
using vespalib::compression::CompressionConfig;
+using vespalib::eval::EngineOrFactory;
namespace document {
@@ -363,10 +364,10 @@ VespaDocumentDeserializer::read(TensorFieldValue &value)
throw DeserializeException(vespalib::make_string("Stream failed size(%zu), needed(%zu) to deserialize tensor field value", _stream.size(), length),
VESPA_STRLOC);
}
- std::unique_ptr<vespalib::tensor::Tensor> tensor;
+ std::unique_ptr<vespalib::eval::Value> tensor;
if (length != 0) {
nbostream wrapStream(_stream.peek(), length);
- tensor = vespalib::tensor::TypedBinaryFormat::deserialize(wrapStream);
+ tensor = EngineOrFactory::get().decode(wrapStream);
if (wrapStream.size() != 0) {
throw DeserializeException("Leftover bytes deserializing tensor field value.", VESPA_STRLOC);
}
diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
index 6d9c08578e7..882dc4e83f3 100644
--- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp
+++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
@@ -26,7 +26,8 @@
#include <vespa/document/update/fieldpathupdates.h>
#include <vespa/document/update/updates.h>
#include <vespa/document/util/bytebuffer.h>
-#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/engine_or_factory.h>
#include <vespa/vespalib/data/databuffer.h>
#include <vespa/vespalib/data/slime/binary_format.h>
#include <vespa/vespalib/objects/nbostream.h>
@@ -370,7 +371,7 @@ VespaDocumentSerializer::write(const TensorFieldValue &value) {
vespalib::nbostream tmpStream;
auto tensor = value.getAsTensorPtr();
if (tensor) {
- vespalib::tensor::TypedBinaryFormat::serialize(tmpStream, *tensor);
+ vespalib::eval::EngineOrFactory::get().encode(*tensor, tmpStream);
assert( ! tmpStream.empty());
_stream.putInt1_4Bytes(tmpStream.size());
_stream.write(tmpStream.peek(), tmpStream.size());
diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp
index d9bec7762b6..91b72329994 100644
--- a/document/src/vespa/document/update/tensor_add_update.cpp
+++ b/document/src/vespa/document/update/tensor_add_update.cpp
@@ -8,6 +8,9 @@
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
#include <vespa/document/serialization/vespadocumentdeserializer.h>
#include <vespa/document/util/serializableexceptions.h>
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/tensor/partial_update.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/stllike/asciistream.h>
@@ -17,8 +20,9 @@
using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
-using vespalib::tensor::Tensor;
using vespalib::make_string;
+using vespalib::eval::EngineOrFactory;
+using vespalib::tensor::TensorPartialUpdate;
namespace document {
@@ -78,14 +82,34 @@ TensorAddUpdate::checkCompatibility(const Field& field) const
}
}
-std::unique_ptr<Tensor>
-TensorAddUpdate::applyTo(const Tensor &tensor) const
+namespace {
+
+std::unique_ptr<vespalib::eval::Value>
+old_add(const vespalib::eval::Value *input,
+ const vespalib::eval::Value *add_cells)
+{
+ auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input);
+ assert(a);
+ auto b = dynamic_cast<const vespalib::tensor::Tensor *>(add_cells);
+ assert(b);
+ return a->add(*b);
+}
+
+} // namespace
+
+std::unique_ptr<vespalib::eval::Value>
+TensorAddUpdate::applyTo(const vespalib::eval::Value &tensor) const
{
auto addTensor = _tensor->getAsTensorPtr();
if (addTensor) {
- return tensor.add(*addTensor);
+ auto engine = EngineOrFactory::get();
+ if (engine.is_factory()) {
+ return TensorPartialUpdate::add(tensor, *addTensor, engine.factory());
+ } else {
+ return old_add(&tensor, addTensor);
+ }
}
- return std::unique_ptr<Tensor>();
+ return std::unique_ptr<vespalib::eval::Value>();
}
bool
diff --git a/document/src/vespa/document/update/tensor_add_update.h b/document/src/vespa/document/update/tensor_add_update.h
index 49519ee1ddd..8687967be49 100644
--- a/document/src/vespa/document/update/tensor_add_update.h
+++ b/document/src/vespa/document/update/tensor_add_update.h
@@ -2,7 +2,7 @@
#include "valueupdate.h"
-namespace vespalib::tensor { class Tensor; }
+namespace vespalib::eval { class Value; }
namespace document {
@@ -27,7 +27,7 @@ public:
bool operator==(const ValueUpdate &other) const override;
const TensorFieldValue &getTensor() const { return *_tensor; }
void checkCompatibility(const Field &field) const override;
- std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const;
+ std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const;
bool applyTo(FieldValue &value) const override;
void printXml(XmlOutputStream &xos) const override;
void print(std::ostream &out, bool verbose, const std::string &indent) const override;
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index 5fbdc2467b3..292b4165540 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -9,8 +9,11 @@
#include <vespa/document/serialization/vespadocumentdeserializer.h>
#include <vespa/document/util/serializableexceptions.h>
#include <vespa/eval/eval/operation.h>
-#include <vespa/eval/tensor/cell_values.h>
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/tensor/partial_update.h>
#include <vespa/eval/tensor/tensor.h>
+#include <vespa/eval/tensor/cell_values.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/util/stringfmt.h>
@@ -19,9 +22,10 @@
using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
-using vespalib::tensor::Tensor;
using vespalib::make_string;
using vespalib::eval::ValueType;
+using vespalib::eval::EngineOrFactory;
+using vespalib::tensor::TensorPartialUpdate;
using join_fun_t = double (*)(double, double);
@@ -156,16 +160,32 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const
}
}
-std::unique_ptr<Tensor>
-TensorModifyUpdate::applyTo(const Tensor &tensor) const
+
+std::unique_ptr<vespalib::eval::Value>
+old_modify(const vespalib::eval::Value *input,
+ const vespalib::eval::Value *modify_spec,
+ join_fun_t function)
+{
+ auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input);
+ // Cells tensor being sparse was validated during deserialize().
+ auto b = dynamic_cast<const vespalib::tensor::SparseTensor *>(modify_spec);
+ vespalib::tensor::CellValues cellValues(*b);
+ return a->modify(function, cellValues);
+}
+
+std::unique_ptr<vespalib::eval::Value>
+TensorModifyUpdate::applyTo(const vespalib::eval::Value &tensor) const
{
auto cellsTensor = _tensor->getAsTensorPtr();
if (cellsTensor) {
- // Cells tensor being sparse was validated during deserialize().
- vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellsTensor));
- return tensor.modify(getJoinFunction(_operation), cellValues);
+ auto engine = EngineOrFactory::get();
+ if (engine.is_factory()) {
+ return TensorPartialUpdate::modify(tensor, getJoinFunction(_operation), *cellsTensor, engine.factory());
+ } else {
+ return old_modify(&tensor, cellsTensor, getJoinFunction(_operation));
+ }
}
- return std::unique_ptr<Tensor>();
+ return std::unique_ptr<vespalib::eval::Value>();
}
bool
@@ -207,13 +227,24 @@ TensorModifyUpdate::print(std::ostream& out, bool verbose, const std::string& in
namespace {
void
-verifyCellsTensorIsSparse(const Tensor *cellsTensor)
+verifyCellsTensorIsSparse(const vespalib::eval::Value *cellsTensor)
{
- if (cellsTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor)) {
- vespalib::string err = make_string("Expected cell values tensor to be sparse, but has type '%s'",
- cellsTensor->type().to_spec().c_str());
- throw IllegalStateException(err, VESPA_STRLOC);
+ if (cellsTensor == nullptr) {
+ return;
+ }
+ auto engine = EngineOrFactory::get();
+ if (engine.is_factory()) {
+ if (cellsTensor->type().is_sparse()) {
+ return;
+ }
+ } else {
+ if (dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor)) {
+ return;
+ }
}
+ vespalib::string err = make_string("Expected cells tensor to be sparse, but has type '%s'",
+ cellsTensor->type().to_spec().c_str());
+ throw IllegalStateException(err, VESPA_STRLOC);
}
}
diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h
index c2d61d3e69b..528ff8c95e9 100644
--- a/document/src/vespa/document/update/tensor_modify_update.h
+++ b/document/src/vespa/document/update/tensor_modify_update.h
@@ -2,7 +2,7 @@
#include "valueupdate.h"
-namespace vespalib::tensor { class Tensor; }
+namespace vespalib::eval { class Value; }
namespace document {
@@ -41,7 +41,7 @@ public:
Operation getOperation() const { return _operation; }
const TensorFieldValue &getTensor() const { return *_tensor; }
void checkCompatibility(const Field &field) const override;
- std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const;
+ std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const;
bool applyTo(FieldValue &value) const override;
void printXml(XmlOutputStream &xos) const override;
void print(std::ostream &out, bool verbose, const std::string &indent) const override;
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index 34a6223e185..178bd1bd950 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -6,18 +6,23 @@
#include <vespa/document/fieldvalue/document.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
#include <vespa/document/serialization/vespadocumentdeserializer.h>
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/tensor/partial_update.h>
+#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/cell_values.h>
#include <vespa/eval/tensor/sparse/sparse_tensor.h>
-#include <vespa/eval/tensor/tensor.h>
+#include <vespa/eval/eval/value.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/xmlstream.h>
#include <ostream>
+#include <cassert>
using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
-using vespalib::tensor::Tensor;
using vespalib::make_string;
using vespalib::eval::ValueType;
+using vespalib::eval::EngineOrFactory;
+using vespalib::tensor::TensorPartialUpdate;
namespace document {
@@ -35,6 +40,16 @@ convertToCompatibleType(const TensorDataType &tensorType)
return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list), tensorType.getTensorType().cell_type()));
}
+std::unique_ptr<vespalib::eval::Value>
+old_remove(const vespalib::eval::Value *input,
+ const vespalib::eval::Value *remove_spec)
+{
+ auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input);
+ auto b = dynamic_cast<const vespalib::tensor::SparseTensor *>(remove_spec);
+ vespalib::tensor::CellValues cellAddresses(*b);
+ return a->remove(cellAddresses);
+}
+
}
IMPLEMENT_IDENTIFIABLE(TensorRemoveUpdate, ValueUpdate);
@@ -102,16 +117,19 @@ TensorRemoveUpdate::checkCompatibility(const Field &field) const
}
}
-std::unique_ptr<Tensor>
-TensorRemoveUpdate::applyTo(const Tensor &tensor) const
+std::unique_ptr<vespalib::eval::Value>
+TensorRemoveUpdate::applyTo(const vespalib::eval::Value &tensor) const
{
auto addressTensor = _tensor->getAsTensorPtr();
if (addressTensor) {
- // Address tensor being sparse was validated during deserialize().
- vespalib::tensor::CellValues cellAddresses(static_cast<const vespalib::tensor::SparseTensor &>(*addressTensor));
- return tensor.remove(cellAddresses);
+ auto engine = EngineOrFactory::get();
+ if (engine.is_factory()) {
+ return TensorPartialUpdate::remove(tensor, *addressTensor, engine.factory());
+ } else {
+ return old_remove(&tensor, addressTensor);
+ }
}
- return std::unique_ptr<Tensor>();
+ return std::unique_ptr<vespalib::eval::Value>();
}
bool
@@ -153,15 +171,27 @@ TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &in
namespace {
void
-verifyAddressTensorIsSparse(const Tensor *addressTensor)
+verifyAddressTensorIsSparse(const vespalib::eval::Value *addressTensor)
{
- if (addressTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor)) {
- vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'",
- addressTensor->type().to_spec().c_str());
- throw IllegalStateException(err, VESPA_STRLOC);
+ if (addressTensor == nullptr) {
+ return;
+ }
+ auto engine = EngineOrFactory::get();
+ if (engine.is_factory()) {
+ if (addressTensor->type().is_sparse()) {
+ return;
+ }
+ } else {
+ if (dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor)) {
+ return;
+ }
}
+ vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'",
+ addressTensor->type().to_spec().c_str());
+ throw IllegalStateException(err, VESPA_STRLOC);
}
+
}
void
diff --git a/document/src/vespa/document/update/tensor_remove_update.h b/document/src/vespa/document/update/tensor_remove_update.h
index e75348fa829..6ab66048dd4 100644
--- a/document/src/vespa/document/update/tensor_remove_update.h
+++ b/document/src/vespa/document/update/tensor_remove_update.h
@@ -2,7 +2,7 @@
#include "valueupdate.h"
-namespace vespalib::tensor { class Tensor; }
+namespace vespalib::eval { class Value; }
namespace document {
@@ -30,7 +30,7 @@ public:
TensorRemoveUpdate &operator=(const TensorRemoveUpdate &rhs);
TensorRemoveUpdate &operator=(TensorRemoveUpdate &&rhs);
const TensorFieldValue &getTensor() const { return *_tensor; }
- std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const;
+ std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const;
bool operator==(const ValueUpdate &other) const override;
void checkCompatibility(const Field &field) const override;