summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2020-06-12 14:07:49 +0000
committerGeir Storli <geirst@verizonmedia.com>2020-06-15 08:05:53 +0000
commit8c0fa90a06966d7f6411e915b7c0d6906c53b130 (patch)
treebaa17856f237d4aff8022fee8272171c41979884
parenta600b2141c93c668d0f6aa4c2b5bc680e8eaf380 (diff)
Implement initial support for two-phase puts in attribute writer.
This is only turned on for tensor attributes with a hnsw index that allows multi-threaded indexing.
-rw-r--r--searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h11
-rw-r--r--searchcommon/src/vespa/searchcommon/attribute/status.cpp36
-rw-r--r--searchcommon/src/vespa/searchcommon/attribute/status.h10
-rw-r--r--searchcore/src/tests/proton/attribute/attribute_test.cpp34
-rw-r--r--searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp189
-rw-r--r--searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h33
-rw-r--r--searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp84
-rw-r--r--searchcore/src/vespa/searchcore/proton/common/attribute_updater.h11
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp2
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_attribute.h5
10 files changed, 368 insertions, 47 deletions
diff --git a/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h b/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h
index 3e3683ce60f..c8b196023d6 100644
--- a/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h
+++ b/searchcommon/src/vespa/searchcommon/attribute/hnsw_index_params.h
@@ -16,24 +16,29 @@ private:
uint32_t _neighbors_to_explore_at_insert;
// This is always the same as in the attribute config, and is duplicated here to simplify usage.
DistanceMetric _distance_metric;
+ bool _allow_multi_threaded_indexing;
public:
HnswIndexParams(uint32_t max_links_per_node_in,
uint32_t neighbors_to_explore_at_insert_in,
- DistanceMetric distance_metric_in)
+ DistanceMetric distance_metric_in,
+ bool allow_multi_threaded_indexing_in = false)
: _max_links_per_node(max_links_per_node_in),
_neighbors_to_explore_at_insert(neighbors_to_explore_at_insert_in),
- _distance_metric(distance_metric_in)
+ _distance_metric(distance_metric_in),
+ _allow_multi_threaded_indexing(allow_multi_threaded_indexing_in)
{}
uint32_t max_links_per_node() const { return _max_links_per_node; }
uint32_t neighbors_to_explore_at_insert() const { return _neighbors_to_explore_at_insert; }
DistanceMetric distance_metric() const { return _distance_metric; }
+ bool allow_multi_threaded_indexing() const { return _allow_multi_threaded_indexing; }
bool operator==(const HnswIndexParams& rhs) const {
return (_max_links_per_node == rhs._max_links_per_node &&
_neighbors_to_explore_at_insert == rhs._neighbors_to_explore_at_insert &&
- _distance_metric == rhs._distance_metric);
+ _distance_metric == rhs._distance_metric &&
+ _allow_multi_threaded_indexing == rhs._allow_multi_threaded_indexing);
}
};
diff --git a/searchcommon/src/vespa/searchcommon/attribute/status.cpp b/searchcommon/src/vespa/searchcommon/attribute/status.cpp
index da13548ec2e..f2bb49c348a 100644
--- a/searchcommon/src/vespa/searchcommon/attribute/status.cpp
+++ b/searchcommon/src/vespa/searchcommon/attribute/status.cpp
@@ -20,6 +20,42 @@ Status::Status()
{
}
+Status::Status(const Status& rhs)
+ : _numDocs(rhs._numDocs),
+ _numValues(rhs._numValues),
+ _numUniqueValues(rhs._numUniqueValues),
+ _allocated(rhs._allocated),
+ _used(rhs._used),
+ _dead(rhs._dead),
+ _unused(rhs._unused),
+ _onHold(rhs._onHold),
+ _onHoldMax(rhs._onHoldMax),
+ _lastSyncToken(rhs.getLastSyncToken()),
+ _updates(rhs._updates),
+ _nonIdempotentUpdates(rhs._nonIdempotentUpdates),
+ _bitVectors(rhs._bitVectors)
+{
+}
+
+Status&
+Status::operator=(const Status& rhs)
+{
+ _numDocs = rhs._numDocs;
+ _numValues = rhs._numValues;
+ _numUniqueValues = rhs._numUniqueValues;
+ _allocated = rhs._allocated;
+ _used = rhs._used;
+ _dead = rhs._dead;
+ _unused = rhs._unused;
+ _onHold = rhs._onHold;
+ _onHoldMax = rhs._onHoldMax;
+ setLastSyncToken(rhs.getLastSyncToken());
+ _updates = rhs._updates;
+ _nonIdempotentUpdates = rhs._nonIdempotentUpdates;
+ _bitVectors = rhs._bitVectors;
+ return *this;
+}
+
vespalib::string
Status::createName(vespalib::stringref index, vespalib::stringref attr)
{
diff --git a/searchcommon/src/vespa/searchcommon/attribute/status.h b/searchcommon/src/vespa/searchcommon/attribute/status.h
index 888355b3f58..a624309da65 100644
--- a/searchcommon/src/vespa/searchcommon/attribute/status.h
+++ b/searchcommon/src/vespa/searchcommon/attribute/status.h
@@ -3,6 +3,7 @@
#pragma once
#include <vespa/vespalib/stllike/string.h>
+#include <atomic>
namespace search::attribute {
@@ -10,6 +11,8 @@ class Status
{
public:
Status();
+ Status(const Status& rhs);
+ Status& operator=(const Status& rhs);
void updateStatistics(uint64_t numValues, uint64_t numUniqueValue, uint64_t allocated,
uint64_t used, uint64_t dead, uint64_t onHold);
@@ -22,14 +25,15 @@ public:
uint64_t getDead() const { return _dead; }
uint64_t getOnHold() const { return _onHold; }
uint64_t getOnHoldMax() const { return _onHoldMax; }
- uint64_t getLastSyncToken() const { return _lastSyncToken; }
+ // This might be accessed from other threads than the writer thread.
+ uint64_t getLastSyncToken() const { return _lastSyncToken.load(std::memory_order_relaxed); }
uint64_t getUpdateCount() const { return _updates; }
uint64_t getNonIdempotentUpdateCount() const { return _nonIdempotentUpdates; }
uint32_t getBitVectors() const { return _bitVectors; }
void setNumDocs(uint64_t v) { _numDocs = v; }
void incNumDocs() { ++_numDocs; }
- void setLastSyncToken(uint64_t v) { _lastSyncToken = v; }
+ void setLastSyncToken(uint64_t v) { _lastSyncToken.store(v, std::memory_order_relaxed); }
void incUpdates(uint64_t v=1) { _updates += v; }
void incNonIdempotentUpdates(uint64_t v = 1) { _nonIdempotentUpdates += v; }
void incBitVectors() { ++_bitVectors; }
@@ -47,7 +51,7 @@ private:
uint64_t _unused;
uint64_t _onHold;
uint64_t _onHoldMax;
- uint64_t _lastSyncToken;
+ std::atomic<uint64_t> _lastSyncToken;
uint64_t _updates;
uint64_t _nonIdempotentUpdates;
uint32_t _bitVectors;
diff --git a/searchcore/src/tests/proton/attribute/attribute_test.cpp b/searchcore/src/tests/proton/attribute/attribute_test.cpp
index 0c21fabc27a..c101c3e2bd5 100644
--- a/searchcore/src/tests/proton/attribute/attribute_test.cpp
+++ b/searchcore/src/tests/proton/attribute/attribute_test.cpp
@@ -61,6 +61,8 @@ using proton::test::AttributeUtils;
using proton::test::MockAttributeManager;
using search::TuneFileAttributes;
using search::attribute::BitVectorSearchCache;
+using search::attribute::DistanceMetric;
+using search::attribute::HnswIndexParams;
using search::attribute::IAttributeVector;
using search::attribute::ImportedAttributeVector;
using search::attribute::ImportedAttributeVectorFactory;
@@ -760,6 +762,38 @@ TEST_F(AttributeWriterTest, spreads_write_over_3_write_contexts)
putAttributes(*this, {0, 1, 2});
}
+AVConfig
+get_tensor_config(bool allow_multi_threaded_indexing)
+{
+ AVConfig cfg(AVBasicType::TENSOR);
+ cfg.setTensorType(ValueType::from_spec("tensor(x[2])"));
+ cfg.set_hnsw_index_params(HnswIndexParams(4, 4, DistanceMetric::Euclidean, allow_multi_threaded_indexing));
+ return cfg;
+}
+
+TEST_F(AttributeWriterTest, tensor_attributes_using_two_phase_put_are_in_separate_write_contexts)
+{
+ addAttribute("a1");
+ addAttribute({"t1", get_tensor_config(true)});
+ addAttribute({"t2", get_tensor_config(true)});
+ addAttribute({"t3", get_tensor_config(false)});
+ allocAttributeWriter();
+
+ const auto& ctx = _aw->get_write_contexts();
+ EXPECT_EQ(3, ctx.size());
+ EXPECT_FALSE(ctx[0].use_two_phase_put());
+ EXPECT_EQ(2, ctx[0].getFields().size());
+
+ EXPECT_TRUE(ctx[1].use_two_phase_put());
+ EXPECT_EQ(1, ctx[1].getFields().size());
+ EXPECT_EQ("t1", ctx[1].getFields()[0].getAttribute().getName());
+
+ EXPECT_TRUE(ctx[2].use_two_phase_put());
+ EXPECT_EQ(1, ctx[2].getFields().size());
+ EXPECT_EQ("t2", ctx[2].getFields()[0].getAttribute().getName());
+}
+
+
ImportedAttributeVector::SP
createImportedAttribute(const vespalib::string &name)
{
diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp
index 33b9d162163..07dfbc1eac7 100644
--- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp
+++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp
@@ -12,26 +12,48 @@
#include <vespa/searchcore/proton/common/attribute_updater.h>
#include <vespa/searchlib/attribute/attributevector.hpp>
#include <vespa/searchlib/attribute/imported_attribute_vector.h>
+#include <vespa/searchlib/tensor/prepare_result.h>
#include <vespa/searchlib/common/idestructorcallback.h>
#include <vespa/vespalib/stllike/hash_map.hpp>
+#include <future>
#include <vespa/log/log.h>
LOG_SETUP(".proton.attribute.attribute_writer");
using namespace document;
using namespace search;
+
+using ExecutorId = vespalib::ISequencedTaskExecutor::ExecutorId;
using search::attribute::ImportedAttributeVector;
+using search::tensor::PrepareResult;
using vespalib::ISequencedTaskExecutor;
-using ExecutorId = vespalib::ISequencedTaskExecutor::ExecutorId;
namespace proton {
using LidVector = LidVectorContext::LidVector;
+namespace {
+
+bool
+use_two_phase_put_for_attribute(const AttributeVector& attr)
+{
+ const auto& cfg = attr.getConfig();
+ if (cfg.basicType() == search::attribute::BasicType::Type::TENSOR &&
+ cfg.hnsw_index_params().has_value() &&
+ cfg.hnsw_index_params().value().allow_multi_threaded_indexing())
+ {
+ return true;
+ }
+ return false;
+}
+
+}
+
AttributeWriter::WriteField::WriteField(AttributeVector &attribute)
: _fieldPath(),
_attribute(attribute),
- _structFieldAttribute(false)
+ _structFieldAttribute(false),
+ _use_two_phase_put(use_two_phase_put_for_attribute(attribute))
{
const vespalib::string &name = attribute.getName();
_structFieldAttribute = attribute::isStructFieldAttribute(name);
@@ -57,11 +79,11 @@ AttributeWriter::WriteField::buildFieldPath(const DocumentType &docType)
AttributeWriter::WriteContext::WriteContext(ExecutorId executorId)
: _executorId(executorId),
_fields(),
- _hasStructFieldAttribute(false)
+ _hasStructFieldAttribute(false),
+ _use_two_phase_put(false)
{
}
-
AttributeWriter::WriteContext::WriteContext(WriteContext &&rhs) noexcept = default;
AttributeWriter::WriteContext::~WriteContext() = default;
@@ -75,6 +97,13 @@ AttributeWriter::WriteContext::add(AttributeVector &attr)
if (_fields.back().isStructFieldAttribute()) {
_hasStructFieldAttribute = true;
}
+ if (_fields.back().use_two_phase_put()) {
+ // Only support for one field per context when this is true.
+ assert(_fields.size() == 1);
+ _use_two_phase_put = true;
+ } else {
+ assert(!_use_two_phase_put);
+ }
}
void
@@ -113,6 +142,26 @@ applyPutToAttribute(SerialNum serialNum, const FieldValue::UP &fieldValue, Docum
}
void
+complete_put_to_attribute(SerialNum serial_num,
+ uint32_t docid,
+ AttributeVector& attr,
+ const FieldValue::SP& field_value,
+ std::unique_ptr<PrepareResult> prepare_result,
+ bool immediate_commit,
+ AttributeWriter::OnWriteDoneType)
+{
+ ensureLidSpace(serial_num, docid, attr);
+ if (field_value.get()) {
+ AttributeUpdater::complete_set_value(attr, docid, *field_value, std::move(prepare_result));
+ } else {
+ attr.clearDoc(docid);
+ }
+ if (immediate_commit) {
+ attr.commit(serial_num, serial_num);
+ }
+}
+
+void
applyRemoveToAttribute(SerialNum serialNum, DocumentIdT lid, bool immediateCommit,
AttributeVector &attr, AttributeWriter::OnWriteDoneType)
{
@@ -148,7 +197,6 @@ applyReplayDone(uint32_t docIdLimit, AttributeVector &attr)
attr.shrinkLidSpace();
}
-
void
applyHeartBeat(SerialNum serialNum, AttributeVector &attr)
{
@@ -166,7 +214,6 @@ applyCommit(SerialNum serialNum, AttributeWriter::OnWriteDoneType , AttributeVec
}
}
-
void
applyCompactLidSpace(uint32_t wantedLidLimit, SerialNum serialNum, AttributeVector &attr)
{
@@ -208,7 +255,6 @@ struct BatchUpdateTask : public vespalib::Executor::Task {
}
}
-
SerialNum _serialNum;
DocumentIdT _lid;
bool _immediateCommit;
@@ -221,6 +267,7 @@ class FieldContext
vespalib::string _name;
ExecutorId _executorId;
AttributeVector *_attr;
+ bool _use_two_phase_put;
public:
FieldContext(ISequencedTaskExecutor &writer, AttributeVector *attr);
@@ -228,13 +275,14 @@ public:
bool operator<(const FieldContext &rhs) const;
ExecutorId getExecutorId() const { return _executorId; }
AttributeVector *getAttribute() const { return _attr; }
+ bool use_two_phase_put() const { return _use_two_phase_put; }
};
-
FieldContext::FieldContext(ISequencedTaskExecutor &writer, AttributeVector *attr)
: _name(attr->getName()),
_executorId(writer.getExecutorId(attr->getNamePrefix())),
- _attr(attr)
+ _attr(attr),
+ _use_two_phase_put(use_two_phase_put_for_attribute(*attr))
{
}
@@ -303,6 +351,101 @@ PutTask::run()
}
}
+
+class PreparePutTask : public vespalib::Executor::Task {
+private:
+ const SerialNum _serial_num;
+ const uint32_t _docid;
+ AttributeVector& _attr;
+ FieldValue::SP _field_value;
+ std::promise<std::unique_ptr<PrepareResult>> _result_promise;
+
+public:
+ PreparePutTask(SerialNum serial_num_in,
+ uint32_t docid_in,
+ const AttributeWriter::WriteField& field,
+ std::shared_ptr<DocumentFieldExtractor> field_extractor);
+ ~PreparePutTask() override;
+ void run() override;
+ SerialNum serial_num() const { return _serial_num; }
+ uint32_t docid() const { return _docid; }
+ AttributeVector& attr() { return _attr; }
+ FieldValue::SP field_value() { return _field_value; }
+ std::future<std::unique_ptr<PrepareResult>> result_future() {
+ return _result_promise.get_future();
+ }
+};
+
+PreparePutTask::PreparePutTask(SerialNum serial_num_in,
+ uint32_t docid_in,
+ const AttributeWriter::WriteField& field,
+ std::shared_ptr<DocumentFieldExtractor> field_extractor)
+ : _serial_num(serial_num_in),
+ _docid(docid_in),
+ _attr(field.getAttribute()),
+ _field_value(),
+ _result_promise()
+{
+ // Note: No need to store the field extractor as we are not extracting struct fields.
+ auto value = field_extractor->getFieldValue(field.getFieldPath());
+ _field_value.reset(value.release());
+}
+
+PreparePutTask::~PreparePutTask() = default;
+
+void
+PreparePutTask::run()
+{
+ if (_attr.getStatus().getLastSyncToken() < _serial_num) {
+ if (_field_value.get()) {
+ _result_promise.set_value(AttributeUpdater::prepare_set_value(_attr, _docid, *_field_value));
+ }
+ }
+}
+
+class CompletePutTask : public vespalib::Executor::Task {
+private:
+ const SerialNum _serial_num;
+ const uint32_t _docid;
+ AttributeVector& _attr;
+ FieldValue::SP _field_value;
+ std::future<std::unique_ptr<PrepareResult>> _result_future;
+ const bool _immediate_commit;
+ std::remove_reference_t<AttributeWriter::OnWriteDoneType> _on_write_done;
+
+public:
+ CompletePutTask(PreparePutTask& prepare_task,
+ bool immediate_commit,
+ AttributeWriter::OnWriteDoneType on_write_done);
+ ~CompletePutTask() override;
+ void run() override;
+};
+
+CompletePutTask::CompletePutTask(PreparePutTask& prepare_task,
+ bool immediate_commit,
+ AttributeWriter::OnWriteDoneType on_write_done)
+ : _serial_num(prepare_task.serial_num()),
+ _docid(prepare_task.docid()),
+ _attr(prepare_task.attr()),
+ _field_value(prepare_task.field_value()),
+ _result_future(prepare_task.result_future()),
+ _immediate_commit(immediate_commit),
+ _on_write_done(on_write_done)
+{
+}
+
+CompletePutTask::~CompletePutTask() = default;
+
+void
+CompletePutTask::run()
+{
+ if (_attr.getStatus().getLastSyncToken() < _serial_num) {
+ auto result = _result_future.get();
+ complete_put_to_attribute(_serial_num, _docid, _attr, _field_value, std::move(result),
+ _immediate_commit, _on_write_done);
+ }
+}
+
class RemoveTask : public vespalib::Executor::Task
{
const AttributeWriter::WriteContext &_wc;
@@ -316,7 +459,6 @@ public:
void run() override;
};
-
RemoveTask::RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, bool immediateCommit, AttributeWriter::OnWriteDoneType onWriteDone)
: _wc(wc),
_serialNum(serialNum),
@@ -419,13 +561,22 @@ AttributeWriter::setupWriteContexts()
fieldContexts.emplace_back(_attributeFieldWriter, attr);
}
std::sort(fieldContexts.begin(), fieldContexts.end());
- for (auto &fc : fieldContexts) {
+ for (const auto& fc : fieldContexts) {
+ if (fc.use_two_phase_put()) {
+ continue;
+ }
if (_writeContexts.empty() ||
(_writeContexts.back().getExecutorId() != fc.getExecutorId())) {
_writeContexts.emplace_back(fc.getExecutorId());
}
_writeContexts.back().add(*fc.getAttribute());
}
+ for (const auto& fc : fieldContexts) {
+ if (fc.use_two_phase_put()) {
+ _writeContexts.emplace_back(fc.getExecutorId());
+ _writeContexts.back().add(*fc.getAttribute());
+ }
+ }
for (const auto &wc : _writeContexts) {
if (wc.hasStructFieldAttribute()) {
_hasStructFieldAttribute = true;
@@ -452,9 +603,19 @@ AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentI
}
auto extractor = std::make_shared<DocumentFieldExtractor>(doc);
for (const auto &wc : _writeContexts) {
- if (allAttributes || wc.hasStructFieldAttribute()) {
- auto putTask = std::make_unique<PutTask>(wc, serialNum, extractor, lid, immediateCommit, allAttributes, onWriteDone);
- _attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(putTask));
+ if (wc.use_two_phase_put()) {
+ assert(wc.getFields().size() == 1);
+ auto prepare_task = std::make_unique<PreparePutTask>(serialNum, lid, wc.getFields()[0], extractor);
+ auto complete_task = std::make_unique<CompletePutTask>(*prepare_task, immediateCommit, onWriteDone);
+ // We use the local docid to create an executor id to round-robin between the threads.
+ _attributeFieldWriter.executeTask(_attributeFieldWriter.getExecutorId(lid), std::move(prepare_task));
+ _attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(complete_task));
+ } else {
+ if (allAttributes || wc.hasStructFieldAttribute()) {
+ auto putTask = std::make_unique<PutTask>(wc, serialNum, extractor, lid, immediateCommit, allAttributes,
+ onWriteDone);
+ _attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(putTask));
+ }
}
}
}
diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h
index 4a9726dd113..726379220e3 100644
--- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h
+++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h
@@ -19,20 +19,23 @@ namespace proton {
class AttributeWriter : public IAttributeWriter
{
private:
- typedef search::AttributeVector AttributeVector;
- typedef document::FieldPath FieldPath;
- typedef document::DataType DataType;
- typedef document::DocumentType DocumentType;
- typedef document::FieldValue FieldValue;
+ using AttributeVector = search::AttributeVector;
+ using FieldPath = document::FieldPath;
+ using DataType = document::DataType;
+ using DocumentType = document::DocumentType;
+ using FieldValue = document::FieldValue;
const IAttributeManager::SP _mgr;
vespalib::ISequencedTaskExecutor &_attributeFieldWriter;
using ExecutorId = vespalib::ISequencedTaskExecutor::ExecutorId;
public:
- class WriteField
- {
+ /**
+ * Represents an attribute vector for a field and details about how to write to it.
+ */
+ class WriteField {
FieldPath _fieldPath;
AttributeVector &_attribute;
bool _structFieldAttribute; // in array/map of struct
+ bool _use_two_phase_put;
public:
WriteField(AttributeVector &attribute);
~WriteField();
@@ -40,12 +43,18 @@ public:
const FieldPath &getFieldPath() const { return _fieldPath; }
void buildFieldPath(const DocumentType &docType);
bool isStructFieldAttribute() const { return _structFieldAttribute; }
+ bool use_two_phase_put() const { return _use_two_phase_put; }
};
- class WriteContext
- {
+
+ /**
+ * Represents a set of fields (as attributes) that are handled by the same write thread.
+ */
+ class WriteContext {
ExecutorId _executorId;
std::vector<WriteField> _fields;
bool _hasStructFieldAttribute;
+ // When this is true, the context only contains a single field.
+ bool _use_two_phase_put;
public:
WriteContext(ExecutorId executorId);
WriteContext(WriteContext &&rhs) noexcept;
@@ -56,6 +65,7 @@ public:
ExecutorId getExecutorId() const { return _executorId; }
const std::vector<WriteField> &getFields() const { return _fields; }
bool hasStructFieldAttribute() const { return _hasStructFieldAttribute; }
+ bool use_two_phase_put() const { return _use_two_phase_put; }
};
private:
using AttrWithId = std::pair<search::AttributeVector *, ExecutorId>;
@@ -103,6 +113,11 @@ public:
void onReplayDone(uint32_t docIdLimit) override;
bool hasStructFieldAttribute() const override;
+
+ // Should only be used for unit testing.
+ const std::vector<WriteContext>& get_write_contexts() const {
+ return _writeContexts;
+ }
};
} // namespace proton
diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
index 8fd47c17acb..d7cf6caff28 100644
--- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
+++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
@@ -31,6 +31,7 @@ LOG_SETUP(".proton.common.attribute_updater");
using namespace document;
using vespalib::make_string;
+using search::tensor::PrepareResult;
using search::tensor::TensorAttribute;
using search::attribute::ReferenceAttribute;
@@ -471,27 +472,33 @@ AttributeUpdater::updateValue(StringAttribute & vec, uint32_t lid, const FieldVa
}
}
+namespace {
+
+template <typename ExpFieldValueType>
void
-AttributeUpdater::updateValue(PredicateAttribute &vec, uint32_t lid, const FieldValue &val)
+validate_field_value_type(const FieldValue& val, const vespalib::string& attr_type, const vespalib::string& value_type)
{
- if (!val.inherits(PredicateFieldValue::classId)) {
+ if (!val.inherits(ExpFieldValueType::classId)) {
throw UpdateException(
- make_string("PredicateAttribute must be updated with "
- "PredicateFieldValues."));
+ make_string("%s must be updated with %s, but was '%s'",
+ attr_type.c_str(), value_type.c_str(), val.toString(false).c_str()));
}
+}
+
+}
+
+void
+AttributeUpdater::updateValue(PredicateAttribute &vec, uint32_t lid, const FieldValue &val)
+{
+ validate_field_value_type<PredicateFieldValue>(val, "PredicateAttribute", "PredicateFieldValue");
vec.updateValue(lid, static_cast<const PredicateFieldValue &>(val));
}
void
AttributeUpdater::updateValue(TensorAttribute &vec, uint32_t lid, const FieldValue &val)
{
- if (!val.inherits(TensorFieldValue::classId)) {
- throw UpdateException(
- make_string("TensorAttribute must be updated with "
- "TensorFieldValues."));
- }
- const auto &tensor = static_cast<const TensorFieldValue &>(val).
- getAsTensorPtr();
+ validate_field_value_type<TensorFieldValue>(val, "TensorAttribute", "TensorFieldValue");
+ const auto &tensor = static_cast<const TensorFieldValue &>(val).getAsTensorPtr();
if (tensor) {
vec.setTensor(lid, *tensor);
} else {
@@ -506,7 +513,7 @@ AttributeUpdater::updateValue(ReferenceAttribute &vec, uint32_t lid, const Field
vec.clearDoc(lid);
throw UpdateException(
make_string("ReferenceAttribute must be updated with "
- "ReferenceFieldValues."));
+ "ReferenceFieldValue, but was '%s'", val.toString(false).c_str()));
}
const auto &reffv = static_cast<const ReferenceFieldValue &>(val);
if (reffv.hasValidDocumentId()) {
@@ -516,4 +523,57 @@ AttributeUpdater::updateValue(ReferenceAttribute &vec, uint32_t lid, const Field
}
}
+namespace {
+
+void
+validate_tensor_attribute_type(AttributeVector& attr)
+{
+ const auto& info = attr.getClass();
+ if (!info.inherits(TensorAttribute::classId)) {
+ throw UpdateException(
+ make_string("Expected attribute vector '%s' to be a TensorAttribute, but was '%s'",
+ attr.getName().c_str(), info.name()));
+ }
+}
+
+std::unique_ptr<PrepareResult>
+prepare_set_tensor(TensorAttribute& attr, uint32_t docid, const FieldValue& val)
+{
+ validate_field_value_type<TensorFieldValue>(val, "TensorAttribute", "TensorFieldValue");
+ const auto& tensor = static_cast<const TensorFieldValue&>(val).getAsTensorPtr();
+ if (tensor) {
+ return attr.prepare_set_tensor(docid, *tensor);
+ }
+ return std::unique_ptr<PrepareResult>();
+}
+
+void
+complete_set_tensor(TensorAttribute& attr, uint32_t docid, const FieldValue& val, std::unique_ptr<PrepareResult> prepare_result)
+{
+ validate_field_value_type<TensorFieldValue>(val, "TensorAttribute", "TensorFieldValue");
+ const auto& tensor = static_cast<const TensorFieldValue&>(val).getAsTensorPtr();
+ if (tensor) {
+ attr.complete_set_tensor(docid, *tensor, std::move(prepare_result));
+ } else {
+ attr.clearDoc(docid);
+ }
+}
+
+}
+
+std::unique_ptr<PrepareResult>
+AttributeUpdater::prepare_set_value(AttributeVector& attr, uint32_t docid, const FieldValue& val)
+{
+ validate_tensor_attribute_type(attr);
+ return prepare_set_tensor(static_cast<TensorAttribute&>(attr), docid, val);
+}
+
+void
+AttributeUpdater::complete_set_value(AttributeVector& attr, uint32_t docid, const FieldValue& val,
+ std::unique_ptr<PrepareResult> prepare_result)
+{
+ validate_tensor_attribute_type(attr);
+ complete_set_tensor(static_cast<TensorAttribute&>(attr), docid, val, std::move(prepare_result));
+}
+
} // namespace search
diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.h b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.h
index 01be6299692..32d14f6dd5a 100644
--- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.h
+++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.h
@@ -10,7 +10,10 @@ namespace search {
class PredicateAttribute;
-namespace tensor { class TensorAttribute; }
+namespace tensor {
+ class PrepareResult;
+ class TensorAttribute;
+}
namespace attribute {class ReferenceAttribute; }
VESPA_DEFINE_EXCEPTION(UpdateException, vespalib::Exception);
@@ -20,14 +23,18 @@ VESPA_DEFINE_EXCEPTION(UpdateException, vespalib::Exception);
*/
class AttributeUpdater {
using Field = document::Field;
- using FieldValue = document::FieldValue;
using FieldUpdate = document::FieldUpdate;
+ using FieldValue = document::FieldValue;
using ValueUpdate = document::ValueUpdate;
public:
static void handleUpdate(AttributeVector & vec, uint32_t lid, const FieldUpdate & upd);
static void handleValue(AttributeVector & vec, uint32_t lid, const FieldValue & val);
+ static std::unique_ptr<tensor::PrepareResult> prepare_set_value(AttributeVector& attr, uint32_t docid, const FieldValue& val);
+ static void complete_set_value(AttributeVector& attr, uint32_t docid, const FieldValue& val,
+ std::unique_ptr<tensor::PrepareResult> prepare_result);
+
private:
template <typename V>
static void handleUpdate(V & vec, uint32_t lid, const ValueUpdate & upd);
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp
index 979eedec58a..c2a1243f341 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp
@@ -263,7 +263,7 @@ TensorAttribute::prepare_set_tensor(DocId docid, const Tensor& tensor) const
void
TensorAttribute::complete_set_tensor(DocId docid, const Tensor& tensor,
- std::future<std::unique_ptr<PrepareResult>> prepare_result)
+ std::unique_ptr<PrepareResult> prepare_result)
{
(void) docid;
(void) tensor;
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h
index f752b9f7f2e..8380e485172 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h
@@ -7,7 +7,6 @@
#include "tensor_store.h"
#include <vespa/searchlib/attribute/not_implemented_attribute.h>
#include <vespa/vespalib/util/rcuvector.h>
-#include <future>
namespace search::tensor {
@@ -66,9 +65,9 @@ public:
* Performs the complete step in a two-phase operation to set a tensor for a document.
*
* This function is only called by the attribute writer thread.
- * It must wait for the result from the prepare step (via the future) before it does the modifying changes.
+ * It uses the result from the prepare step to do the modifying changes.
*/
- virtual void complete_set_tensor(DocId docid, const Tensor& tensor, std::future<std::unique_ptr<PrepareResult>> prepare_result);
+ virtual void complete_set_tensor(DocId docid, const Tensor& tensor, std::unique_ptr<PrepareResult> prepare_result);
virtual void compactWorst() = 0;
};