diff options
author | Tor Egge <Tor.Egge@oath.com> | 2018-01-18 10:44:11 +0000 |
---|---|---|
committer | Tor Egge <Tor.Egge@oath.com> | 2018-01-18 10:56:09 +0000 |
commit | 6d32c06e24b0c82785ca3a68175c7aeea0861fe9 (patch) | |
tree | 717ea6698a4cbed24bff01138a17ca01e6a7bd90 | |
parent | 64a0f2929b91415077a78d8d7c509f8df3f535a6 (diff) |
Add read interface for tensor attribute, used by tensor attribute feature
executors.
11 files changed, 66 insertions, 27 deletions
diff --git a/searchlib/src/vespa/searchlib/features/attributefeature.cpp b/searchlib/src/vespa/searchlib/features/attributefeature.cpp index 5f03cda1869..cd7a27b37b5 100644 --- a/searchlib/src/vespa/searchlib/features/attributefeature.cpp +++ b/searchlib/src/vespa/searchlib/features/attributefeature.cpp @@ -24,7 +24,7 @@ using search::attribute::ConstCharContent; using search::tensor::DenseTensorAttribute; using search::attribute::IntegerContent; using search::attribute::FloatContent; -using search::tensor::TensorAttribute; +using search::tensor::ITensorAttribute; using search::attribute::WeightedConstCharContent; using search::attribute::WeightedIntegerContent; using search::attribute::WeightedFloatContent; @@ -391,24 +391,22 @@ createTensorAttributeExecutor(const IAttributeVector *attribute, const vespalib: " Returning empty tensor.", attribute->getName().c_str()); return ConstantTensorExecutor::createEmpty(tensorType, stash); } - const TensorAttribute *tensorAttribute = dynamic_cast<const TensorAttribute *>(attribute); + const ITensorAttribute *tensorAttribute = dynamic_cast<const ITensorAttribute *>(attribute); if (tensorAttribute == nullptr) { LOG(warning, "The attribute vector '%s' could not be converted to a tensor attribute." " Returning empty tensor.", attribute->getName().c_str()); return ConstantTensorExecutor::createEmpty(tensorType, stash); } - if (tensorType != tensorAttribute->getConfig().tensorType()) { + if (tensorType != tensorAttribute->getTensorType()) { LOG(warning, "The tensor attribute '%s' has tensor type '%s'," " while the feature executor expects type '%s'. Returning empty tensor.", - tensorAttribute->getName().c_str(), - tensorAttribute->getConfig().tensorType().to_spec().c_str(), + attribute->getName().c_str(), + tensorAttribute->getTensorType().to_spec().c_str(), tensorType.to_spec().c_str()); return ConstantTensorExecutor::createEmpty(tensorType, stash); } if (tensorType.is_dense()) { - const DenseTensorAttribute *denseTensorAttribute = dynamic_cast<const DenseTensorAttribute *>(tensorAttribute); - assert(denseTensorAttribute != nullptr); - return stash.create<DenseTensorAttributeExecutor>(denseTensorAttribute); + return stash.create<DenseTensorAttributeExecutor>(tensorAttribute); } return stash.create<TensorAttributeExecutor>(tensorAttribute); } diff --git a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp index 487bc724e07..1f554cb9af7 100644 --- a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp +++ b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp @@ -1,9 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "dense_tensor_attribute_executor.h" -#include <vespa/searchlib/tensor/dense_tensor_attribute.h> +#include <vespa/searchlib/tensor/i_tensor_attribute.h> -using search::tensor::DenseTensorAttribute; +using search::tensor::ITensorAttribute; using vespalib::eval::Tensor; using vespalib::tensor::MutableDenseTensorView; @@ -11,9 +11,9 @@ namespace search { namespace features { DenseTensorAttributeExecutor:: -DenseTensorAttributeExecutor(const DenseTensorAttribute *attribute) +DenseTensorAttributeExecutor(const ITensorAttribute *attribute) : _attribute(attribute), - _tensorView(_attribute->getConfig().tensorType()) + _tensorView(_attribute->getTensorType()) { } diff --git a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h index ac3d327c12a..3b66d3d0ba1 100644 --- a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h +++ b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h @@ -7,7 +7,7 @@ #include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h> namespace search { -namespace tensor { class DenseTensorAttribute; } +namespace tensor { class ITensorAttribute; } namespace features { /** @@ -17,11 +17,11 @@ namespace features { class DenseTensorAttributeExecutor : public fef::FeatureExecutor { private: - const search::tensor::DenseTensorAttribute *_attribute; + const search::tensor::ITensorAttribute *_attribute; vespalib::tensor::MutableDenseTensorView _tensorView; public: - DenseTensorAttributeExecutor(const search::tensor::DenseTensorAttribute *attribute); + DenseTensorAttributeExecutor(const search::tensor::ITensorAttribute *attribute); void execute(uint32_t docId) override; }; diff --git a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp index 03393d6f590..51727846f95 100644 --- a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp +++ b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp @@ -1,13 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "tensor_attribute_executor.h" -#include <vespa/searchlib/tensor/tensor_attribute.h> +#include <vespa/searchlib/tensor/i_tensor_attribute.h> namespace search { namespace features { TensorAttributeExecutor:: -TensorAttributeExecutor(const search::tensor::TensorAttribute *attribute) +TensorAttributeExecutor(const search::tensor::ITensorAttribute *attribute) : _attribute(attribute), _emptyTensor(attribute->getEmptyTensor()), _tensor() diff --git a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h index 0f1e21c8cad..3f981e06482 100644 --- a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h +++ b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h @@ -9,18 +9,18 @@ #include <vespa/eval/tensor/default_tensor.h> namespace search { -namespace tensor { class TensorAttribute; } +namespace tensor { class ITensorAttribute; } namespace features { class TensorAttributeExecutor : public fef::FeatureExecutor { private: - const search::tensor::TensorAttribute *_attribute; + const search::tensor::ITensorAttribute *_attribute; std::unique_ptr<vespalib::eval::Tensor> _emptyTensor; std::unique_ptr<vespalib::eval::Tensor> _tensor; public: - TensorAttributeExecutor(const search::tensor::TensorAttribute *attribute); + TensorAttributeExecutor(const search::tensor::ITensorAttribute *attribute); void execute(uint32_t docId) override; }; diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h index d3bbfbab300..68ca6ae7295 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h @@ -23,11 +23,11 @@ public: virtual ~DenseTensorAttribute(); virtual void setTensor(DocId docId, const Tensor &tensor) override; virtual std::unique_ptr<Tensor> getTensor(DocId docId) const override; + virtual void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override; virtual bool onLoad() override; virtual std::unique_ptr<AttributeSaver> onInitSave() override; virtual void compactWorst() override; virtual uint32_t getVersion() const override; - void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const; }; diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp index 07377630299..76ce89d9b45 100644 --- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp @@ -70,6 +70,12 @@ GenericTensorAttribute::getTensor(DocId docId) const return _genericTensorStore.getTensor(ref); } +void +GenericTensorAttribute::getTensor(DocId, vespalib::tensor::MutableDenseTensorView &) const +{ + notImplemented(); +} + bool GenericTensorAttribute::onLoad() { diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.h index f7d7184d8ec..948e72cd41a 100644 --- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.h @@ -20,6 +20,7 @@ public: virtual ~GenericTensorAttribute(); virtual void setTensor(DocId docId, const Tensor &tensor) override; virtual std::unique_ptr<Tensor> getTensor(DocId docId) const override; + virtual void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override; virtual bool onLoad() override; virtual std::unique_ptr<AttributeSaver> onInitSave() override; virtual void compactWorst() override; diff --git a/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h new file mode 100644 index 00000000000..6c83d3caae9 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/i_tensor_attribute.h @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <memory> + +namespace vespalib::tensor { +class MutableDenseTensorView; +class Tensor; +} +namespace vespalib::eval { class ValueType; } + +namespace search::tensor { + +/** + * Interface for tensor attribute used by feature executors to get information. + */ +class ITensorAttribute +{ +public: + using Tensor = vespalib::tensor::Tensor; + + virtual ~ITensorAttribute() {} + virtual std::unique_ptr<Tensor> getTensor(uint32_t docId) const = 0; + virtual std::unique_ptr<Tensor> getEmptyTensor() const = 0; + virtual void getTensor(uint32_t docId, vespalib::tensor::MutableDenseTensorView &tensor) const = 0; + virtual vespalib::eval::ValueType getTensorType() const = 0; +}; + +} // namespace search::tensor diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp index c2b51db004e..0ca0d2fa776 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp @@ -162,6 +162,12 @@ TensorAttribute::getEmptyTensor() const return createEmptyTensor(_tensorMapper.get()); } +vespalib::eval::ValueType +TensorAttribute::getTensorType() const +{ + return getConfig().tensorType(); +} + void TensorAttribute::clearDocs(DocId lidLow, DocId lidLimit) { diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h index 758b115861b..35dbb3ab21b 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h @@ -2,13 +2,12 @@ #pragma once +#include "i_tensor_attribute.h" #include <vespa/searchlib/attribute/not_implemented_attribute.h> #include "tensor_store.h" #include <vespa/searchlib/common/rcuvector.h> #include <vespa/eval/tensor/tensor_mapper.h> -namespace vespalib { namespace tensor { class Tensor; } } - namespace search { namespace tensor { @@ -16,7 +15,7 @@ namespace tensor { /** * Attribute vector class used to store tensors for all documents in memory. */ -class TensorAttribute : public NotImplementedAttribute +class TensorAttribute : public NotImplementedAttribute, public ITensorAttribute { protected: using RefType = TensorStore::EntryRef; @@ -33,7 +32,6 @@ protected: public: DECLARE_IDENTIFIABLE_ABSTRACT(TensorAttribute); using RefCopyVector = vespalib::Array<RefType>; - using Tensor = vespalib::tensor::Tensor; TensorAttribute(const vespalib::stringref &baseFileName, const Config &cfg, TensorStore &tensorStore); virtual ~TensorAttribute(); @@ -43,13 +41,13 @@ public: virtual void removeOldGenerations(generation_t firstUsed) override; virtual void onGenerationChange(generation_t generation) override; virtual bool addDoc(DocId &docId) override; - std::unique_ptr<Tensor> getEmptyTensor() const; + virtual std::unique_ptr<Tensor> getEmptyTensor() const override; + virtual vespalib::eval::ValueType getTensorType() const override; virtual void clearDocs(DocId lidLow, DocId lidLimit) override; virtual void onShrinkLidSpace() override; virtual uint32_t getVersion() const override; RefCopyVector getRefCopy() const; virtual void setTensor(DocId docId, const Tensor &tensor) = 0; - virtual std::unique_ptr<Tensor> getTensor(DocId docId) const = 0; virtual void compactWorst() = 0; }; |