diff options
author | Geir Storli <geirst@verizonmedia.com> | 2020-06-12 14:07:49 +0000 |
---|---|---|
committer | Geir Storli <geirst@verizonmedia.com> | 2020-06-17 13:17:54 +0000 |
commit | 5566148a0ad253569e44b4e21ada7e8e59241eaf (patch) | |
tree | 43434e8a049323a2c956b23539a2d969f92cefb8 | |
parent | e55f4b911fd3dab1270514edc99b5e3d3a086833 (diff) |
Implement initial support for two-phase puts in attribute writer.
This is only turned on for tensor attributes with an HNSW index that allows multi-threaded indexing.
10 files changed, 368 insertions, 47 deletions
diff --git a/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h b/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h index 3e3683ce60f..c8b196023d6 100644 --- a/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h +++ b/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h @@ -16,24 +16,29 @@ private: uint32_t _neighbors_to_explore_at_insert; // This is always the same as in the attribute config, and is duplicated here to simplify usage. DistanceMetric _distance_metric; + bool _allow_multi_threaded_indexing; public: HnswIndexParams(uint32_t max_links_per_node_in, uint32_t neighbors_to_explore_at_insert_in, - DistanceMetric distance_metric_in) + DistanceMetric distance_metric_in, + bool allow_multi_threaded_indexing_in = false) : _max_links_per_node(max_links_per_node_in), _neighbors_to_explore_at_insert(neighbors_to_explore_at_insert_in), - _distance_metric(distance_metric_in) + _distance_metric(distance_metric_in), + _allow_multi_threaded_indexing(allow_multi_threaded_indexing_in) {} uint32_t max_links_per_node() const { return _max_links_per_node; } uint32_t neighbors_to_explore_at_insert() const { return _neighbors_to_explore_at_insert; } DistanceMetric distance_metric() const { return _distance_metric; } + bool allow_multi_threaded_indexing() const { return _allow_multi_threaded_indexing; } bool operator==(const HnswIndexParams& rhs) const { return (_max_links_per_node == rhs._max_links_per_node && _neighbors_to_explore_at_insert == rhs._neighbors_to_explore_at_insert && - _distance_metric == rhs._distance_metric); + _distance_metric == rhs._distance_metric && + _allow_multi_threaded_indexing == rhs._allow_multi_threaded_indexing); } }; diff --git a/searchcommon/src/vespa/searchcommon/attribute/status.cpp b/searchcommon/src/vespa/searchcommon/attribute/status.cpp index da13548ec2e..f2bb49c348a 100644 --- a/searchcommon/src/vespa/searchcommon/attribute/status.cpp +++ 
b/searchcommon/src/vespa/searchcommon/attribute/status.cpp @@ -20,6 +20,42 @@ Status::Status() { } +Status::Status(const Status& rhs) + : _numDocs(rhs._numDocs), + _numValues(rhs._numValues), + _numUniqueValues(rhs._numUniqueValues), + _allocated(rhs._allocated), + _used(rhs._used), + _dead(rhs._dead), + _unused(rhs._unused), + _onHold(rhs._onHold), + _onHoldMax(rhs._onHoldMax), + _lastSyncToken(rhs.getLastSyncToken()), + _updates(rhs._updates), + _nonIdempotentUpdates(rhs._nonIdempotentUpdates), + _bitVectors(rhs._bitVectors) +{ +} + +Status& +Status::operator=(const Status& rhs) +{ + _numDocs = rhs._numDocs; + _numValues = rhs._numValues; + _numUniqueValues = rhs._numUniqueValues; + _allocated = rhs._allocated; + _used = rhs._used; + _dead = rhs._dead; + _unused = rhs._unused; + _onHold = rhs._onHold; + _onHoldMax = rhs._onHoldMax; + setLastSyncToken(rhs.getLastSyncToken()); + _updates = rhs._updates; + _nonIdempotentUpdates = rhs._nonIdempotentUpdates; + _bitVectors = rhs._bitVectors; + return *this; +} + vespalib::string Status::createName(vespalib::stringref index, vespalib::stringref attr) { diff --git a/searchcommon/src/vespa/searchcommon/attribute/status.h b/searchcommon/src/vespa/searchcommon/attribute/status.h index 888355b3f58..a624309da65 100644 --- a/searchcommon/src/vespa/searchcommon/attribute/status.h +++ b/searchcommon/src/vespa/searchcommon/attribute/status.h @@ -3,6 +3,7 @@ #pragma once #include <vespa/vespalib/stllike/string.h> +#include <atomic> namespace search::attribute { @@ -10,6 +11,8 @@ class Status { public: Status(); + Status(const Status& rhs); + Status& operator=(const Status& rhs); void updateStatistics(uint64_t numValues, uint64_t numUniqueValue, uint64_t allocated, uint64_t used, uint64_t dead, uint64_t onHold); @@ -22,14 +25,15 @@ public: uint64_t getDead() const { return _dead; } uint64_t getOnHold() const { return _onHold; } uint64_t getOnHoldMax() const { return _onHoldMax; } - uint64_t getLastSyncToken() const { return 
_lastSyncToken; } + // This might be accessed from other threads than the writer thread. + uint64_t getLastSyncToken() const { return _lastSyncToken.load(std::memory_order_relaxed); } uint64_t getUpdateCount() const { return _updates; } uint64_t getNonIdempotentUpdateCount() const { return _nonIdempotentUpdates; } uint32_t getBitVectors() const { return _bitVectors; } void setNumDocs(uint64_t v) { _numDocs = v; } void incNumDocs() { ++_numDocs; } - void setLastSyncToken(uint64_t v) { _lastSyncToken = v; } + void setLastSyncToken(uint64_t v) { _lastSyncToken.store(v, std::memory_order_relaxed); } void incUpdates(uint64_t v=1) { _updates += v; } void incNonIdempotentUpdates(uint64_t v = 1) { _nonIdempotentUpdates += v; } void incBitVectors() { ++_bitVectors; } @@ -47,7 +51,7 @@ private: uint64_t _unused; uint64_t _onHold; uint64_t _onHoldMax; - uint64_t _lastSyncToken; + std::atomic<uint64_t> _lastSyncToken; uint64_t _updates; uint64_t _nonIdempotentUpdates; uint32_t _bitVectors; diff --git a/searchcore/src/tests/proton/attribute/attribute_test.cpp b/searchcore/src/tests/proton/attribute/attribute_test.cpp index 0c21fabc27a..c101c3e2bd5 100644 --- a/searchcore/src/tests/proton/attribute/attribute_test.cpp +++ b/searchcore/src/tests/proton/attribute/attribute_test.cpp @@ -61,6 +61,8 @@ using proton::test::AttributeUtils; using proton::test::MockAttributeManager; using search::TuneFileAttributes; using search::attribute::BitVectorSearchCache; +using search::attribute::DistanceMetric; +using search::attribute::HnswIndexParams; using search::attribute::IAttributeVector; using search::attribute::ImportedAttributeVector; using search::attribute::ImportedAttributeVectorFactory; @@ -760,6 +762,38 @@ TEST_F(AttributeWriterTest, spreads_write_over_3_write_contexts) putAttributes(*this, {0, 1, 2}); } +AVConfig +get_tensor_config(bool allow_multi_threaded_indexing) +{ + AVConfig cfg(AVBasicType::TENSOR); + cfg.setTensorType(ValueType::from_spec("tensor(x[2])")); + 
cfg.set_hnsw_index_params(HnswIndexParams(4, 4, DistanceMetric::Euclidean, allow_multi_threaded_indexing)); + return cfg; +} + +TEST_F(AttributeWriterTest, tensor_attributes_using_two_phase_put_are_in_separate_write_contexts) +{ + addAttribute("a1"); + addAttribute({"t1", get_tensor_config(true)}); + addAttribute({"t2", get_tensor_config(true)}); + addAttribute({"t3", get_tensor_config(false)}); + allocAttributeWriter(); + + const auto& ctx = _aw->get_write_contexts(); + EXPECT_EQ(3, ctx.size()); + EXPECT_FALSE(ctx[0].use_two_phase_put()); + EXPECT_EQ(2, ctx[0].getFields().size()); + + EXPECT_TRUE(ctx[1].use_two_phase_put()); + EXPECT_EQ(1, ctx[1].getFields().size()); + EXPECT_EQ("t1", ctx[1].getFields()[0].getAttribute().getName()); + + EXPECT_TRUE(ctx[2].use_two_phase_put()); + EXPECT_EQ(1, ctx[2].getFields().size()); + EXPECT_EQ("t2", ctx[2].getFields()[0].getAttribute().getName()); +} + + ImportedAttributeVector::SP createImportedAttribute(const vespalib::string &name) { diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp index 33b9d162163..07dfbc1eac7 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp @@ -12,26 +12,48 @@ #include <vespa/searchcore/proton/common/attribute_updater.h> #include <vespa/searchlib/attribute/attributevector.hpp> #include <vespa/searchlib/attribute/imported_attribute_vector.h> +#include <vespa/searchlib/tensor/prepare_result.h> #include <vespa/searchlib/common/idestructorcallback.h> #include <vespa/vespalib/stllike/hash_map.hpp> +#include <future> #include <vespa/log/log.h> LOG_SETUP(".proton.attribute.attribute_writer"); using namespace document; using namespace search; + +using ExecutorId = vespalib::ISequencedTaskExecutor::ExecutorId; using search::attribute::ImportedAttributeVector; +using search::tensor::PrepareResult; 
using vespalib::ISequencedTaskExecutor; -using ExecutorId = vespalib::ISequencedTaskExecutor::ExecutorId; namespace proton { using LidVector = LidVectorContext::LidVector; +namespace { + +bool +use_two_phase_put_for_attribute(const AttributeVector& attr) +{ + const auto& cfg = attr.getConfig(); + if (cfg.basicType() == search::attribute::BasicType::Type::TENSOR && + cfg.hnsw_index_params().has_value() && + cfg.hnsw_index_params().value().allow_multi_threaded_indexing()) + { + return true; + } + return false; +} + +} + AttributeWriter::WriteField::WriteField(AttributeVector &attribute) : _fieldPath(), _attribute(attribute), - _structFieldAttribute(false) + _structFieldAttribute(false), + _use_two_phase_put(use_two_phase_put_for_attribute(attribute)) { const vespalib::string &name = attribute.getName(); _structFieldAttribute = attribute::isStructFieldAttribute(name); @@ -57,11 +79,11 @@ AttributeWriter::WriteField::buildFieldPath(const DocumentType &docType) AttributeWriter::WriteContext::WriteContext(ExecutorId executorId) : _executorId(executorId), _fields(), - _hasStructFieldAttribute(false) + _hasStructFieldAttribute(false), + _use_two_phase_put(false) { } - AttributeWriter::WriteContext::WriteContext(WriteContext &&rhs) noexcept = default; AttributeWriter::WriteContext::~WriteContext() = default; @@ -75,6 +97,13 @@ AttributeWriter::WriteContext::add(AttributeVector &attr) if (_fields.back().isStructFieldAttribute()) { _hasStructFieldAttribute = true; } + if (_fields.back().use_two_phase_put()) { + // Only support for one field per context when this is true. 
+ assert(_fields.size() == 1); + _use_two_phase_put = true; + } else { + assert(!_use_two_phase_put); + } } void @@ -113,6 +142,26 @@ applyPutToAttribute(SerialNum serialNum, const FieldValue::UP &fieldValue, Docum } void +complete_put_to_attribute(SerialNum serial_num, + uint32_t docid, + AttributeVector& attr, + const FieldValue::SP& field_value, + std::unique_ptr<PrepareResult> prepare_result, + bool immediate_commit, + AttributeWriter::OnWriteDoneType) +{ + ensureLidSpace(serial_num, docid, attr); + if (field_value.get()) { + AttributeUpdater::complete_set_value(attr, docid, *field_value, std::move(prepare_result)); + } else { + attr.clearDoc(docid); + } + if (immediate_commit) { + attr.commit(serial_num, serial_num); + } +} + +void applyRemoveToAttribute(SerialNum serialNum, DocumentIdT lid, bool immediateCommit, AttributeVector &attr, AttributeWriter::OnWriteDoneType) { @@ -148,7 +197,6 @@ applyReplayDone(uint32_t docIdLimit, AttributeVector &attr) attr.shrinkLidSpace(); } - void applyHeartBeat(SerialNum serialNum, AttributeVector &attr) { @@ -166,7 +214,6 @@ applyCommit(SerialNum serialNum, AttributeWriter::OnWriteDoneType , AttributeVec } } - void applyCompactLidSpace(uint32_t wantedLidLimit, SerialNum serialNum, AttributeVector &attr) { @@ -208,7 +255,6 @@ struct BatchUpdateTask : public vespalib::Executor::Task { } } - SerialNum _serialNum; DocumentIdT _lid; bool _immediateCommit; @@ -221,6 +267,7 @@ class FieldContext vespalib::string _name; ExecutorId _executorId; AttributeVector *_attr; + bool _use_two_phase_put; public: FieldContext(ISequencedTaskExecutor &writer, AttributeVector *attr); @@ -228,13 +275,14 @@ public: bool operator<(const FieldContext &rhs) const; ExecutorId getExecutorId() const { return _executorId; } AttributeVector *getAttribute() const { return _attr; } + bool use_two_phase_put() const { return _use_two_phase_put; } }; - FieldContext::FieldContext(ISequencedTaskExecutor &writer, AttributeVector *attr) : _name(attr->getName()), 
_executorId(writer.getExecutorId(attr->getNamePrefix())), - _attr(attr) + _attr(attr), + _use_two_phase_put(use_two_phase_put_for_attribute(*attr)) { } @@ -303,6 +351,101 @@ PutTask::run() } } + +class PreparePutTask : public vespalib::Executor::Task { +private: + const SerialNum _serial_num; + const uint32_t _docid; + AttributeVector& _attr; + FieldValue::SP _field_value; + std::promise<std::unique_ptr<PrepareResult>> _result_promise; + +public: + PreparePutTask(SerialNum serial_num_in, + uint32_t docid_in, + const AttributeWriter::WriteField& field, + std::shared_ptr<DocumentFieldExtractor> field_extractor); + ~PreparePutTask() override; + void run() override; + SerialNum serial_num() const { return _serial_num; } + uint32_t docid() const { return _docid; } + AttributeVector& attr() { return _attr; } + FieldValue::SP field_value() { return _field_value; } + std::future<std::unique_ptr<PrepareResult>> result_future() { + return _result_promise.get_future(); + } +}; + +PreparePutTask::PreparePutTask(SerialNum serial_num_in, + uint32_t docid_in, + const AttributeWriter::WriteField& field, + std::shared_ptr<DocumentFieldExtractor> field_extractor) + : _serial_num(serial_num_in), + _docid(docid_in), + _attr(field.getAttribute()), + _field_value(), + _result_promise() +{ + // Note: No need to store the field extractor as we are not extracting struct fields. 
+ auto value = field_extractor->getFieldValue(field.getFieldPath()); + _field_value.reset(value.release()); +} + +PreparePutTask::~PreparePutTask() = default; + +void +PreparePutTask::run() +{ + if (_attr.getStatus().getLastSyncToken() < _serial_num) { + if (_field_value.get()) { + _result_promise.set_value(AttributeUpdater::prepare_set_value(_attr, _docid, *_field_value)); + } + } +} + +class CompletePutTask : public vespalib::Executor::Task { +private: + const SerialNum _serial_num; + const uint32_t _docid; + AttributeVector& _attr; + FieldValue::SP _field_value; + std::future<std::unique_ptr<PrepareResult>> _result_future; + const bool _immediate_commit; + std::remove_reference_t<AttributeWriter::OnWriteDoneType> _on_write_done; + +public: + CompletePutTask(PreparePutTask& prepare_task, + bool immediate_commit, + AttributeWriter::OnWriteDoneType on_write_done); + ~CompletePutTask() override; + void run() override; +}; + +CompletePutTask::CompletePutTask(PreparePutTask& prepare_task, + bool immediate_commit, + AttributeWriter::OnWriteDoneType on_write_done) + : _serial_num(prepare_task.serial_num()), + _docid(prepare_task.docid()), + _attr(prepare_task.attr()), + _field_value(prepare_task.field_value()), + _result_future(prepare_task.result_future()), + _immediate_commit(immediate_commit), + _on_write_done(on_write_done) +{ +} + +CompletePutTask::~CompletePutTask() = default; + +void +CompletePutTask::run() +{ + if (_attr.getStatus().getLastSyncToken() < _serial_num) { + auto result = _result_future.get(); + complete_put_to_attribute(_serial_num, _docid, _attr, _field_value, std::move(result), + _immediate_commit, _on_write_done); + } +} + class RemoveTask : public vespalib::Executor::Task { const AttributeWriter::WriteContext &_wc; @@ -316,7 +459,6 @@ public: void run() override; }; - RemoveTask::RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, bool immediateCommit, AttributeWriter::OnWriteDoneType onWriteDone) : _wc(wc), 
_serialNum(serialNum), @@ -419,13 +561,22 @@ AttributeWriter::setupWriteContexts() fieldContexts.emplace_back(_attributeFieldWriter, attr); } std::sort(fieldContexts.begin(), fieldContexts.end()); - for (auto &fc : fieldContexts) { + for (const auto& fc : fieldContexts) { + if (fc.use_two_phase_put()) { + continue; + } if (_writeContexts.empty() || (_writeContexts.back().getExecutorId() != fc.getExecutorId())) { _writeContexts.emplace_back(fc.getExecutorId()); } _writeContexts.back().add(*fc.getAttribute()); } + for (const auto& fc : fieldContexts) { + if (fc.use_two_phase_put()) { + _writeContexts.emplace_back(fc.getExecutorId()); + _writeContexts.back().add(*fc.getAttribute()); + } + } for (const auto &wc : _writeContexts) { if (wc.hasStructFieldAttribute()) { _hasStructFieldAttribute = true; @@ -452,9 +603,19 @@ AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentI } auto extractor = std::make_shared<DocumentFieldExtractor>(doc); for (const auto &wc : _writeContexts) { - if (allAttributes || wc.hasStructFieldAttribute()) { - auto putTask = std::make_unique<PutTask>(wc, serialNum, extractor, lid, immediateCommit, allAttributes, onWriteDone); - _attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(putTask)); + if (wc.use_two_phase_put()) { + assert(wc.getFields().size() == 1); + auto prepare_task = std::make_unique<PreparePutTask>(serialNum, lid, wc.getFields()[0], extractor); + auto complete_task = std::make_unique<CompletePutTask>(*prepare_task, immediateCommit, onWriteDone); + // We use the local docid to create an executor id to round-robin between the threads. 
+ _attributeFieldWriter.executeTask(_attributeFieldWriter.getExecutorId(lid), std::move(prepare_task)); + _attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(complete_task)); + } else { + if (allAttributes || wc.hasStructFieldAttribute()) { + auto putTask = std::make_unique<PutTask>(wc, serialNum, extractor, lid, immediateCommit, allAttributes, + onWriteDone); + _attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(putTask)); + } } } } diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h index 4a9726dd113..726379220e3 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h +++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h @@ -19,20 +19,23 @@ namespace proton { class AttributeWriter : public IAttributeWriter { private: - typedef search::AttributeVector AttributeVector; - typedef document::FieldPath FieldPath; - typedef document::DataType DataType; - typedef document::DocumentType DocumentType; - typedef document::FieldValue FieldValue; + using AttributeVector = search::AttributeVector; + using FieldPath = document::FieldPath; + using DataType = document::DataType; + using DocumentType = document::DocumentType; + using FieldValue = document::FieldValue; const IAttributeManager::SP _mgr; vespalib::ISequencedTaskExecutor &_attributeFieldWriter; using ExecutorId = vespalib::ISequencedTaskExecutor::ExecutorId; public: - class WriteField - { + /** + * Represents an attribute vector for a field and details about how to write to it. 
+ */ + class WriteField { FieldPath _fieldPath; AttributeVector &_attribute; bool _structFieldAttribute; // in array/map of struct + bool _use_two_phase_put; public: WriteField(AttributeVector &attribute); ~WriteField(); @@ -40,12 +43,18 @@ public: const FieldPath &getFieldPath() const { return _fieldPath; } void buildFieldPath(const DocumentType &docType); bool isStructFieldAttribute() const { return _structFieldAttribute; } + bool use_two_phase_put() const { return _use_two_phase_put; } }; - class WriteContext - { + + /** + * Represents a set of fields (as attributes) that are handled by the same write thread. + */ + class WriteContext { ExecutorId _executorId; std::vector<WriteField> _fields; bool _hasStructFieldAttribute; + // When this is true, the context only contains a single field. + bool _use_two_phase_put; public: WriteContext(ExecutorId executorId); WriteContext(WriteContext &&rhs) noexcept; @@ -56,6 +65,7 @@ public: ExecutorId getExecutorId() const { return _executorId; } const std::vector<WriteField> &getFields() const { return _fields; } bool hasStructFieldAttribute() const { return _hasStructFieldAttribute; } + bool use_two_phase_put() const { return _use_two_phase_put; } }; private: using AttrWithId = std::pair<search::AttributeVector *, ExecutorId>; @@ -103,6 +113,11 @@ public: void onReplayDone(uint32_t docIdLimit) override; bool hasStructFieldAttribute() const override; + + // Should only be used for unit testing. 
+ const std::vector<WriteContext>& get_write_contexts() const { + return _writeContexts; + } }; } // namespace proton diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp index 8fd47c17acb..d7cf6caff28 100644 --- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp +++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp @@ -31,6 +31,7 @@ LOG_SETUP(".proton.common.attribute_updater"); using namespace document; using vespalib::make_string; +using search::tensor::PrepareResult; using search::tensor::TensorAttribute; using search::attribute::ReferenceAttribute; @@ -471,27 +472,33 @@ AttributeUpdater::updateValue(StringAttribute & vec, uint32_t lid, const FieldVa } } +namespace { + +template <typename ExpFieldValueType> void -AttributeUpdater::updateValue(PredicateAttribute &vec, uint32_t lid, const FieldValue &val) +validate_field_value_type(const FieldValue& val, const vespalib::string& attr_type, const vespalib::string& value_type) { - if (!val.inherits(PredicateFieldValue::classId)) { + if (!val.inherits(ExpFieldValueType::classId)) { throw UpdateException( - make_string("PredicateAttribute must be updated with " - "PredicateFieldValues.")); + make_string("%s must be updated with %s, but was '%s'", + attr_type.c_str(), value_type.c_str(), val.toString(false).c_str())); } +} + +} + +void +AttributeUpdater::updateValue(PredicateAttribute &vec, uint32_t lid, const FieldValue &val) +{ + validate_field_value_type<PredicateFieldValue>(val, "PredicateAttribute", "PredicateFieldValue"); vec.updateValue(lid, static_cast<const PredicateFieldValue &>(val)); } void AttributeUpdater::updateValue(TensorAttribute &vec, uint32_t lid, const FieldValue &val) { - if (!val.inherits(TensorFieldValue::classId)) { - throw UpdateException( - make_string("TensorAttribute must be updated with " - "TensorFieldValues.")); - } - const auto &tensor = 
static_cast<const TensorFieldValue &>(val). - getAsTensorPtr(); + validate_field_value_type<TensorFieldValue>(val, "TensorAttribute", "TensorFieldValue"); + const auto &tensor = static_cast<const TensorFieldValue &>(val).getAsTensorPtr(); if (tensor) { vec.setTensor(lid, *tensor); } else { @@ -506,7 +513,7 @@ AttributeUpdater::updateValue(ReferenceAttribute &vec, uint32_t lid, const Field vec.clearDoc(lid); throw UpdateException( make_string("ReferenceAttribute must be updated with " - "ReferenceFieldValues.")); + "ReferenceFieldValue, but was '%s'", val.toString(false).c_str())); } const auto &reffv = static_cast<const ReferenceFieldValue &>(val); if (reffv.hasValidDocumentId()) { @@ -516,4 +523,57 @@ AttributeUpdater::updateValue(ReferenceAttribute &vec, uint32_t lid, const Field } } +namespace { + +void +validate_tensor_attribute_type(AttributeVector& attr) +{ + const auto& info = attr.getClass(); + if (!info.inherits(TensorAttribute::classId)) { + throw UpdateException( + make_string("Expected attribute vector '%s' to be a TensorAttribute, but was '%s'", + attr.getName().c_str(), info.name())); + } +} + +std::unique_ptr<PrepareResult> +prepare_set_tensor(TensorAttribute& attr, uint32_t docid, const FieldValue& val) +{ + validate_field_value_type<TensorFieldValue>(val, "TensorAttribute", "TensorFieldValue"); + const auto& tensor = static_cast<const TensorFieldValue&>(val).getAsTensorPtr(); + if (tensor) { + return attr.prepare_set_tensor(docid, *tensor); + } + return std::unique_ptr<PrepareResult>(); +} + +void +complete_set_tensor(TensorAttribute& attr, uint32_t docid, const FieldValue& val, std::unique_ptr<PrepareResult> prepare_result) +{ + validate_field_value_type<TensorFieldValue>(val, "TensorAttribute", "TensorFieldValue"); + const auto& tensor = static_cast<const TensorFieldValue&>(val).getAsTensorPtr(); + if (tensor) { + attr.complete_set_tensor(docid, *tensor, std::move(prepare_result)); + } else { + attr.clearDoc(docid); + } +} + +} + 
+std::unique_ptr<PrepareResult> +AttributeUpdater::prepare_set_value(AttributeVector& attr, uint32_t docid, const FieldValue& val) +{ + validate_tensor_attribute_type(attr); + return prepare_set_tensor(static_cast<TensorAttribute&>(attr), docid, val); +} + +void +AttributeUpdater::complete_set_value(AttributeVector& attr, uint32_t docid, const FieldValue& val, + std::unique_ptr<PrepareResult> prepare_result) +{ + validate_tensor_attribute_type(attr); + complete_set_tensor(static_cast<TensorAttribute&>(attr), docid, val, std::move(prepare_result)); +} + } // namespace search diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.h b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.h index 01be6299692..32d14f6dd5a 100644 --- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.h +++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.h @@ -10,7 +10,10 @@ namespace search { class PredicateAttribute; -namespace tensor { class TensorAttribute; } +namespace tensor { + class PrepareResult; + class TensorAttribute; +} namespace attribute {class ReferenceAttribute; } VESPA_DEFINE_EXCEPTION(UpdateException, vespalib::Exception); @@ -20,14 +23,18 @@ VESPA_DEFINE_EXCEPTION(UpdateException, vespalib::Exception); */ class AttributeUpdater { using Field = document::Field; - using FieldValue = document::FieldValue; using FieldUpdate = document::FieldUpdate; + using FieldValue = document::FieldValue; using ValueUpdate = document::ValueUpdate; public: static void handleUpdate(AttributeVector & vec, uint32_t lid, const FieldUpdate & upd); static void handleValue(AttributeVector & vec, uint32_t lid, const FieldValue & val); + static std::unique_ptr<tensor::PrepareResult> prepare_set_value(AttributeVector& attr, uint32_t docid, const FieldValue& val); + static void complete_set_value(AttributeVector& attr, uint32_t docid, const FieldValue& val, + std::unique_ptr<tensor::PrepareResult> prepare_result); + private: template 
<typename V> static void handleUpdate(V & vec, uint32_t lid, const ValueUpdate & upd); diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp index f8db11ae9d8..6cf4f6d2689 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp @@ -263,7 +263,7 @@ TensorAttribute::prepare_set_tensor(DocId docid, const Tensor& tensor) const void TensorAttribute::complete_set_tensor(DocId docid, const Tensor& tensor, - std::future<std::unique_ptr<PrepareResult>> prepare_result) + std::unique_ptr<PrepareResult> prepare_result) { (void) docid; (void) tensor; diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h index f752b9f7f2e..8380e485172 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h @@ -7,7 +7,6 @@ #include "tensor_store.h" #include <vespa/searchlib/attribute/not_implemented_attribute.h> #include <vespa/vespalib/util/rcuvector.h> -#include <future> namespace search::tensor { @@ -66,9 +65,9 @@ public: * Performs the complete step in a two-phase operation to set a tensor for a document. * * This function is only called by the attribute writer thread. - * It must wait for the result from the prepare step (via the future) before it does the modifying changes. + * It uses the result from the prepare step to do the modifying changes. */ - virtual void complete_set_tensor(DocId docid, const Tensor& tensor, std::future<std::unique_ptr<PrepareResult>> prepare_result); + virtual void complete_set_tensor(DocId docid, const Tensor& tensor, std::unique_ptr<PrepareResult> prepare_result); virtual void compactWorst() = 0; }; |