summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@broadpark.no>2019-03-11 19:26:18 +0100
committerTor Egge <Tor.Egge@broadpark.no>2019-03-11 19:26:18 +0100
commit8dd6e095f7f526e5a85db146415abfd9bfa635ff (patch)
tree23c87841f2be4832afb0528bbc17fff66df02bc9 /searchlib
parent48ce50681ad29a1a17446dbb1f0413615ca35725 (diff)
Stop using tensor mapper when setting values in tensor attribute.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp36
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp5
-rw-r--r--searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp5
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp69
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_attribute.h4
5 files changed, 79 insertions, 40 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index d041dde52a5..2e339a069b6 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/document/base/exceptions.h>
#include <vespa/searchlib/tensor/tensor_attribute.h>
#include <vespa/searchlib/tensor/generic_tensor_attribute.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
@@ -13,6 +14,7 @@
#include <vespa/log/log.h>
LOG_SETUP("tensorattribute_test");
+using document::WrongTensorTypeException;
using search::tensor::TensorAttribute;
using search::tensor::DenseTensorAttribute;
using search::tensor::GenericTensorAttribute;
@@ -258,10 +260,14 @@ Fixture::testSetTensorValue()
EXPECT_EQUAL(5u, _attr->getNumDocs());
EXPECT_EQUAL(5u, _attr->getCommittedDocIdLimit());
TEST_DO(assertGetNoTensor(4));
- setTensor(4, *createTensor({}, {}));
+ EXPECT_EXCEPTION(setTensor(4, *createTensor({}, {})),
+ WrongTensorTypeException,
+ "but other tensor type is 'double'");
+ TEST_DO(assertGetNoTensor(4));
+ setTensor(4, *_tensorAttr->getEmptyTensor());
if (_denseTensors) {
TEST_DO(assertGetTensor(*expEmptyDenseTensor(), 4));
- setTensor(3, *createTensor({ {{{"y","1"}}, 11} }, { "x", "y"}));
+ setTensor(3, *expDenseTensor3());
TEST_DO(assertGetTensor(*expDenseTensor3(), 3));
} else {
TEST_DO(assertGetTensor({}, {"x", "y"}, 4));
@@ -277,8 +283,12 @@ void
Fixture::testSaveLoad()
{
ensureSpace(4);
- setTensor(4, *createTensor({}, {}));
- setTensor(3, *createTensor({ {{{"y","1"}}, 11} }, { "x", "y"}));
+ setTensor(4, *_tensorAttr->getEmptyTensor());
+ if (_denseTensors) {
+ setTensor(3, *expDenseTensor3());
+ } else {
+ setTensor(3, *createTensor({ {{{"y","1"}}, 11} }, { "x", "y"}));
+ }
TEST_DO(save());
TEST_DO(load());
EXPECT_EQUAL(5u, _attr->getNumDocs());
@@ -298,10 +308,15 @@ void
Fixture::testCompaction()
{
ensureSpace(4);
- Tensor::UP emptytensor = createTensor({}, {});
+ Tensor::UP emptytensor = _tensorAttr->getEmptyTensor();
Tensor::UP emptyxytensor = createTensor({}, {"x", "y"});
Tensor::UP simpletensor = createTensor({ {{{"y","1"}}, 11} }, { "x", "y"});
Tensor::UP filltensor = createTensor({ {{}, 5} }, { "x", "y"});
+ if (_denseTensors) {
+ emptyxytensor = expEmptyDenseTensor();
+ simpletensor = expDenseTensor3();
+ filltensor = expDenseFillTensor();
+ }
setTensor(4, *emptytensor);
setTensor(3, *simpletensor);
setTensor(2, *filltensor);
@@ -325,11 +340,6 @@ Fixture::testCompaction()
"iter = %" PRIu64 ", memory usage %" PRIu64 ", -> %" PRIu64,
iter, oldStatus.getUsed(), newStatus.getUsed());
TEST_DO(assertGetNoTensor(1));
- if (_denseTensors) {
- emptyxytensor = expEmptyDenseTensor();
- simpletensor = expDenseTensor3();
- filltensor = expDenseFillTensor();
- }
TEST_DO(assertGetTensor(*filltensor, 2));
TEST_DO(assertGetTensor(*simpletensor, 3));
TEST_DO(assertGetTensor(*emptyxytensor, 4));
@@ -371,12 +381,6 @@ Fixture::testEmptyTensor()
}
-TEST_F("Test empty sparse tensor attribute", Fixture("tensor()"))
-{
- f.testEmptyAttribute();
-}
-
-
template <class MakeFixture>
void testAll(MakeFixture &&f)
{
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
index 13091241809..ba4c64f1744 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
@@ -14,7 +14,6 @@ LOG_SETUP(".searchlib.tensor.dense_tensor_attribute");
using vespalib::eval::ValueType;
using vespalib::tensor::MutableDenseTensorView;
using vespalib::tensor::Tensor;
-using vespalib::tensor::TensorMapper;
namespace search::tensor {
@@ -100,8 +99,8 @@ DenseTensorAttribute::~DenseTensorAttribute()
void
DenseTensorAttribute::setTensor(DocId docId, const Tensor &tensor)
{
- EntryRef ref = _denseTensorStore.setTensor(
- (_tensorMapper ? *_tensorMapper->map(tensor) : tensor));
+ checkTensorType(tensor);
+ EntryRef ref = _denseTensorStore.setTensor(tensor);
setTensorRef(docId, ref);
}
diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
index 3ea40433a82..7b5a22f6966 100644
--- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
@@ -11,7 +11,6 @@
using vespalib::eval::ValueType;
using vespalib::tensor::Tensor;
-using vespalib::tensor::TensorMapper;
namespace search::tensor {
@@ -49,8 +48,8 @@ GenericTensorAttribute::~GenericTensorAttribute()
void
GenericTensorAttribute::setTensor(DocId docId, const Tensor &tensor)
{
- EntryRef ref = _genericTensorStore.setTensor(
- (_tensorMapper ? *_tensorMapper->map(tensor) : tensor));
+ checkTensorType(tensor);
+ EntryRef ref = _genericTensorStore.setTensor(tensor);
setTensorRef(docId, ref);
}
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp
index af8c820a3c6..39225e867f8 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp
@@ -1,12 +1,22 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "tensor_attribute.h"
-#include <vespa/eval/tensor/default_tensor.h>
+#include <vespa/document/base/exceptions.h>
+#include <vespa/document/datatype/tensor_data_type.h>
+#include <vespa/eval/eval/simple_tensor.h>
+#include <vespa/eval/tensor/dense/dense_tensor.h>
+#include <vespa/eval/tensor/sparse/sparse_tensor.h>
+#include <vespa/eval/tensor/wrapped_simple_tensor.h>
#include <vespa/searchlib/common/rcuvector.hpp>
+using vespalib::eval::SimpleTensor;
using vespalib::eval::ValueType;
using vespalib::tensor::Tensor;
-using vespalib::tensor::TensorMapper;
+using vespalib::tensor::DenseTensor;
+using vespalib::tensor::SparseTensor;
+using vespalib::tensor::WrappedSimpleTensor;
+using document::TensorDataType;
+using document::WrongTensorTypeException;
namespace search {
@@ -20,20 +30,41 @@ constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0;
constexpr size_t DEAD_SLACK = 0x10000u;
-Tensor::UP
-createEmptyTensor(const TensorMapper *mapper)
+ValueType
+createEmptyTensorType(const ValueType &type)
{
- vespalib::tensor::DefaultTensor::builder builder;
- if (mapper != nullptr) {
- return mapper->map(*builder.build());
+ std::vector<ValueType::Dimension> list;
+ for (const auto &dim : type.dimensions()) {
+ if (dim.is_indexed() && !dim.is_bound()) {
+ list.emplace_back(dim.name, 1);
+ } else {
+ list.emplace_back(dim);
+ }
}
- return builder.build();
+ return ValueType::tensor_type(std::move(list));
}
-bool
-shouldCreateMapper(const ValueType &tensorType)
+Tensor::UP
+createEmptyTensor(const ValueType &type)
+{
+ if (type.is_sparse()) {
+ return std::make_unique<SparseTensor>(type, SparseTensor::Cells());
+ } else if (type.is_dense()) {
+ size_t size = 1;
+ for (const auto &dimension : type.dimensions()) {
+ size *= dimension.size;
+ }
+ return std::make_unique<DenseTensor>(type, DenseTensor::Cells(size));
+ } else {
+ return std::make_unique<WrappedSimpleTensor>(std::make_unique<SimpleTensor>(type, SimpleTensor::Cells()));
+ }
+}
+
+vespalib::string makeWrongTensorTypeMsg(const ValueType &fieldTensorType, const ValueType &tensorType)
{
- return tensorType.is_tensor() && !tensorType.dimensions().empty();
+ return vespalib::make_string("Field tensor type is '%s' but other tensor type is '%s'",
+ fieldTensorType.to_spec().c_str(),
+ tensorType.to_spec().c_str());
}
}
@@ -45,12 +76,9 @@ TensorAttribute::TensorAttribute(vespalib::stringref name, const Config &cfg, Te
cfg.getGrowStrategy().getDocsGrowDelta(),
getGenerationHolder()),
_tensorStore(tensorStore),
- _tensorMapper(),
+ _emptyTensor(createEmptyTensor(createEmptyTensorType(cfg.tensorType()))),
_compactGeneration(0)
{
- if (shouldCreateMapper(cfg.tensorType())) {
- _tensorMapper = std::make_unique<TensorMapper>(cfg.tensorType());
- }
}
@@ -140,6 +168,15 @@ TensorAttribute::addDoc(DocId &docId)
return true;
}
+void
+TensorAttribute::checkTensorType(const Tensor &tensor)
+{
+ const ValueType &fieldTensorType = getConfig().tensorType();
+ const ValueType &tensorType = tensor.type();
+ if (!TensorDataType::isAssignableType(fieldTensorType, tensorType)) {
+ throw WrongTensorTypeException(makeWrongTensorTypeMsg(fieldTensorType, tensorType), VESPA_STRLOC);
+ }
+}
void
TensorAttribute::setTensorRef(DocId docId, EntryRef ref)
@@ -159,7 +196,7 @@ TensorAttribute::setTensorRef(DocId docId, EntryRef ref)
Tensor::UP
TensorAttribute::getEmptyTensor() const
{
- return createEmptyTensor(_tensorMapper.get());
+ return _emptyTensor->clone();
}
vespalib::eval::ValueType
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h
index 06758da8063..4241a899018 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h
@@ -6,7 +6,6 @@
#include <vespa/searchlib/attribute/not_implemented_attribute.h>
#include "tensor_store.h"
#include <vespa/searchlib/common/rcuvector.h>
-#include <vespa/eval/tensor/tensor_mapper.h>
namespace search::tensor {
@@ -21,11 +20,12 @@ protected:
RefVector _refVector; // docId -> ref in data store for serialized tensor
TensorStore &_tensorStore; // data store for serialized tensors
- std::unique_ptr<vespalib::tensor::TensorMapper> _tensorMapper; // mapper to our tensor type
+ std::unique_ptr<Tensor> _emptyTensor;
uint64_t _compactGeneration; // Generation when last compact occurred
template <typename RefType>
void doCompactWorst();
+ void checkTensorType(const Tensor &tensor);
void setTensorRef(DocId docId, EntryRef ref);
public:
DECLARE_IDENTIFIABLE_ABSTRACT(TensorAttribute);