diff options
author | Tor Egge <Tor.Egge@yahooinc.com> | 2023-03-31 19:34:50 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-31 19:34:50 +0200 |
commit | 9d93baa2b9d23258f8e760e3d804ee9065cf9a58 (patch) | |
tree | 594da70ff2a65a76bd96735c55291b261e4fa9a5 /searchlib/src/vespa | |
parent | e0db5db519c291dc9ea9ec994b51fb9499f1e246 (diff) | |
parent | fe96fab936a7f6f920aafca24397febbb556219a (diff) |
Merge pull request #26665 from vespa-engine/toregge/add-tensor-ext-attribute (tag: v8.149.36)
Add TensorExtAttribute.
Diffstat (limited to 'searchlib/src/vespa')
4 files changed, 239 insertions, 0 deletions
diff --git a/searchlib/src/vespa/searchlib/attribute/attributevector.h b/searchlib/src/vespa/searchlib/attribute/attributevector.h index 3d14622ca02..e40785911ea 100644 --- a/searchlib/src/vespa/searchlib/attribute/attributevector.h +++ b/searchlib/src/vespa/searchlib/attribute/attributevector.h @@ -38,6 +38,8 @@ namespace vespalib::alloc { class Alloc; } +namespace vespalib::eval { struct Value; } + namespace search { template <typename T> class ComponentGuard; @@ -86,6 +88,7 @@ public: virtual bool add(double, int32_t = 1) { return false; } virtual bool add(const char *, int32_t = 1) { return false; } virtual bool add(vespalib::ConstArrayRef<char>, int32_t = 1) { return false; } + virtual bool add(const vespalib::eval::Value&, int32_t = 1) { return false; } virtual ~IExtendAttribute() = default; }; diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt index c8c5d4d4257..313863d8dcb 100644 --- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt @@ -40,6 +40,7 @@ vespa_add_library(searchlib_tensor OBJECT tensor_buffer_store.cpp tensor_buffer_type_mapper.cpp tensor_deserialize.cpp + tensor_ext_attribute.cpp tensor_store.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp new file mode 100644 index 00000000000..19c8cf6053b --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.cpp @@ -0,0 +1,181 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include "tensor_ext_attribute.h" +#include "serialized_tensor_ref.h" +#include "vector_bundle.h" +#include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/tensor_spec.h> +#include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/value_codec.h> +#include <vespa/searchcommon/attribute/config.h> + +#include <vespa/log/log.h> +LOG_SETUP(".searchlib.tensor.tensor_ext_attribute"); + +using vespalib::eval::FastValueBuilderFactory; +using vespalib::eval::TensorSpec; +using vespalib::eval::TypedCells; +using vespalib::eval::Value; +using vespalib::eval::ValueType; + +namespace search::tensor { + +namespace { + +std::unique_ptr<Value> +create_empty_tensor(const ValueType& type) +{ + const auto &factory = FastValueBuilderFactory::get(); + TensorSpec empty_spec(type.to_spec()); + return vespalib::eval::value_from_spec(empty_spec, factory); +} + +} + +TensorExtAttribute::TensorExtAttribute(const vespalib::string& name, const Config& cfg) + : NotImplementedAttribute(name, cfg), + ITensorAttribute(), + IExtendAttribute(), + _subspace_type(cfg.tensorType()), + _empty(_subspace_type), + _empty_tensor(create_empty_tensor(cfg.tensorType())) +{ +} + +TensorExtAttribute::~TensorExtAttribute() = default; + +const ITensorAttribute* +TensorExtAttribute::asTensorAttribute() const +{ + return this; +} + +void +TensorExtAttribute::onCommit() +{ + LOG_ABORT("should not be reached"); +} + +void +TensorExtAttribute::onUpdateStat() +{ +} + +bool +TensorExtAttribute::addDoc(DocId& docId) +{ + docId = _data.size(); + _data.emplace_back(nullptr); + incNumDocs(); + setCommittedDocIdLimit(getNumDocs()); + return true; +} + +bool +TensorExtAttribute::add(const vespalib::eval::Value& v, int32_t) +{ + _data.back() = &v; + return true; +} + +IExtendAttribute* +TensorExtAttribute::getExtendInterface() +{ + return this; +} + +TypedCells +TensorExtAttribute::get_vector(uint32_t docid, uint32_t subspace) const +{ + auto vectors = get_vectors(docid); + return (subspace < vectors.subspaces()) 
? vectors.cells(subspace) : _empty.cells(); +} + +VectorBundle +TensorExtAttribute::get_vectors(uint32_t docid) const +{ + auto tensor = _data[docid]; + if (tensor == nullptr) { + return VectorBundle(); + } + return VectorBundle(tensor->cells().data, tensor->index().size(), _subspace_type); +} + +std::unique_ptr<Value> +TensorExtAttribute::getTensor(uint32_t docid) const +{ + auto tensor = _data[docid]; + if (tensor == nullptr) { + return {}; + } + return FastValueBuilderFactory::get().copy(*tensor); +} + +std::unique_ptr<Value> +TensorExtAttribute::getEmptyTensor() const +{ + return FastValueBuilderFactory::get().copy(*_empty_tensor); +} + +TypedCells +TensorExtAttribute::extract_cells_ref(uint32_t docid) const +{ + return get_vector(docid, 0); +} + +const vespalib::eval::Value& +TensorExtAttribute::get_tensor_ref(uint32_t docid) const +{ + auto tensor = _data[docid]; + return (tensor == nullptr) ? *_empty_tensor : *tensor; +} + +SerializedTensorRef +TensorExtAttribute::get_serialized_tensor_ref(uint32_t) const +{ + notImplemented(); +} + +bool +TensorExtAttribute::supports_extract_cells_ref() const +{ + return getConfig().tensorType().is_dense(); +} + +bool +TensorExtAttribute::supports_get_tensor_ref() const +{ + return true; +} + +bool +TensorExtAttribute::supports_get_serialized_tensor_ref() const +{ + return false; +} + +const ValueType& +TensorExtAttribute::getTensorType() const +{ + return getConfig().tensorType(); +} + +TensorExtAttribute::DistanceMetric +TensorExtAttribute::distance_metric() const +{ + return getConfig().distance_metric(); +} + +uint32_t +TensorExtAttribute::get_num_docs() const +{ + return _data.size(); +} + +void +TensorExtAttribute::get_state(const vespalib::slime::Inserter& inserter) const +{ + (void) inserter; +} + +} diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h new file mode 100644 index 00000000000..a58426cd146 --- /dev/null +++ 
b/searchlib/src/vespa/searchlib/tensor/tensor_ext_attribute.h @@ -0,0 +1,54 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "i_tensor_attribute.h" +#include "empty_subspace.h" +#include "subspace_type.h" +#include <vespa/searchlib/attribute/not_implemented_attribute.h> +#include <vespa/vespalib/stllike/allocator.h> + +namespace search::tensor { + +/** + * Attribute vector storing a pointer to single tensor value per + * document in streaming search. The tensor is not owned by this + * attribute vector. + */ +class TensorExtAttribute : public NotImplementedAttribute, + public ITensorAttribute, + public IExtendAttribute +{ + std::vector<const vespalib::eval::Value*> _data; + SubspaceType _subspace_type; + EmptySubspace _empty; + std::unique_ptr<vespalib::eval::Value> _empty_tensor; +public: + TensorExtAttribute(const vespalib::string& name, const Config& cfg); + ~TensorExtAttribute() override; + const ITensorAttribute* asTensorAttribute() const override; + void onCommit() override; + void onUpdateStat() override; + bool addDoc(DocId& docId) override; + bool add(const vespalib::eval::Value& v, int32_t) override; + IExtendAttribute* getExtendInterface() override; + // DocVectorAccess API + vespalib::eval::TypedCells get_vector(uint32_t docid, uint32_t subspace) const override; + VectorBundle get_vectors(uint32_t docid) const override; + + // ITensorAttribute API + std::unique_ptr<vespalib::eval::Value> getTensor(uint32_t docid) const override; + std::unique_ptr<vespalib::eval::Value> getEmptyTensor() const override; + vespalib::eval::TypedCells extract_cells_ref(uint32_t docid) const override; + const vespalib::eval::Value& get_tensor_ref(uint32_t docid) const override; + SerializedTensorRef get_serialized_tensor_ref(uint32_t docid) const override; + bool supports_extract_cells_ref() const override; + bool supports_get_tensor_ref() const override; + bool 
supports_get_serialized_tensor_ref() const override; + const vespalib::eval::ValueType & getTensorType() const override; + search::attribute::DistanceMetric distance_metric() const override; + uint32_t get_num_docs() const override; + void get_state(const vespalib::slime::Inserter& inserter) const override; +}; + +} |