diff options
author | Geir Storli <geirst@yahooinc.com> | 2022-10-09 20:12:34 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-09 20:12:34 +0200 |
commit | 240a62de8a9b3c93fb9f7031f5e204264d414817 (patch) | |
tree | 1f5e0bb204f8a98d7bf8fdf0da48472de27a3ab8 /searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp | |
parent | 14e1b2febd40c3fa89d09b64569b4634f5594acc (diff) | |
parent | a45d929b88b17c9de0d53d4c6fc3b25815bbe233 (diff) |
Merge pull request #24367 from vespa-engine/toregge/share-code-for-loading-and-saving-tensor-attributev8.65.41
Share code for loading and saving tensor attribute between
Diffstat (limited to 'searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp')
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp index cdaea07176a..99a30b59bd1 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "tensor_attribute.h" +#include "blob_sequence_reader.h" +#include "tensor_store_saver.h" #include <vespa/document/base/exceptions.h> #include <vespa/document/datatype/tensor_data_type.h> #include <vespa/searchlib/attribute/address_space_components.h> @@ -277,6 +279,51 @@ TensorAttribute::getRefCopy() const return result; } +bool +TensorAttribute::onLoad(vespalib::Executor*) +{ + BlobSequenceReader tensorReader(*this); + if (!tensorReader.hasData()) { + return false; + } + setCreateSerialNum(tensorReader.getCreateSerialNum()); + assert(tensorReader.getVersion() == getVersion()); + uint32_t numDocs = tensorReader.getDocIdLimit(); + _refVector.reset(); + _refVector.unsafe_reserve(numDocs); + vespalib::Array<char> buffer(1024); + for (uint32_t lid = 0; lid < numDocs; ++lid) { + uint32_t tensorSize = tensorReader.getNextSize(); + if (tensorSize != 0) { + if (tensorSize > buffer.size()) { + buffer.resize(tensorSize + 1024); + } + tensorReader.readBlob(&buffer[0], tensorSize); + vespalib::nbostream source(&buffer[0], tensorSize); + EntryRef ref = _tensorStore.store_encoded_tensor(source); + _refVector.push_back(AtomicEntryRef(ref)); + } else { + EntryRef invalid; + _refVector.push_back(AtomicEntryRef(invalid)); + } + } + setNumDocs(numDocs); + setCommittedDocIdLimit(numDocs); + return true; +} + +std::unique_ptr<AttributeSaver> +TensorAttribute::onInitSave(vespalib::stringref fileName) +{ + vespalib::GenerationHandler::Guard guard(getGenerationHandler(). + takeGuard()); + return std::make_unique<TensorStoreSaver> + (std::move(guard), + this->createAttributeHeader(fileName), + getRefCopy(), + _tensorStore); +} + void TensorAttribute::update_tensor(DocId docId, const document::TensorUpdate &update, |