diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2020-08-28 16:10:16 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-28 16:10:16 +0200 |
commit | 5f13cd33ac9015e0bd17f9c2c8080e8183480b92 (patch) | |
tree | 57e7e21556a952bee1d605592a6c343fc60ab8ee | |
parent | 9a8332e0fc7f38e914b51dcda5fed8aad73a044f (diff) | |
parent | b0b3699572067be679a393d41fe495c3fad3f85f (diff) |
Merge pull request #14185 from vespa-engine/arnej/load-direct-tensor-attribute
Arnej/load direct tensor attribute
9 files changed, 152 insertions, 33 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt index 35615b255c0..8e5405ee179 100644 --- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt @@ -5,6 +5,7 @@ vespa_add_library(searchlib_tensor OBJECT dense_tensor_attribute.cpp dense_tensor_attribute_saver.cpp dense_tensor_store.cpp + direct_tensor_attribute.cpp distance_function_factory.cpp distance_functions.cpp generic_tensor_attribute.cpp @@ -20,6 +21,7 @@ vespa_add_library(searchlib_tensor OBJECT nearest_neighbor_index.cpp nearest_neighbor_index_saver.cpp tensor_attribute.cpp + tensor_deserialize.cpp tensor_store.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h b/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h new file mode 100644 index 00000000000..7c34b60e93d --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h @@ -0,0 +1,26 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/fastlib/io/bufferedfile.h> +#include <vespa/searchlib/attribute/readerbase.h> +#include <vespa/searchlib/util/fileutil.h> + +namespace search::tensor { + +/** + * Utility for reading an attribute data file where + * the format is a sequence of blobs (size, byte[size]). + **/ +class BlobSequenceReader : public ReaderBase +{ +private: + FileReader<uint32_t> _sizeReader; +public: + BlobSequenceReader(AttributeVector &attr) + : ReaderBase(attr), + _sizeReader(*_datFile) + { } + uint32_t getNextSize() { return _sizeReader.readHostOrder(); } + void readBlob(void *buf, size_t len) { _datFile->ReadBuf(buf, len); } +}; + +} // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp index 76533839de7..37a042d4e7f 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp @@ -30,26 +30,26 @@ namespace { constexpr uint32_t DENSE_TENSOR_ATTRIBUTE_VERSION = 1; const vespalib::string tensorTypeTag("tensortype"); -class TensorReader : public ReaderBase +class BlobSequenceReader : public ReaderBase { private: static constexpr uint8_t tensorIsNotPresent = 0; static constexpr uint8_t tensorIsPresent = 1; public: - TensorReader(AttributeVector &attr); - ~TensorReader(); + BlobSequenceReader(AttributeVector &attr); + ~BlobSequenceReader(); bool is_present(); void readTensor(void *buf, size_t len) { _datFile->ReadBuf(buf, len); } }; -TensorReader::TensorReader(AttributeVector &attr) +BlobSequenceReader::BlobSequenceReader(AttributeVector &attr) : ReaderBase(attr) { } -TensorReader::~TensorReader() = default; +BlobSequenceReader::~BlobSequenceReader() = default; bool -TensorReader::is_present() { +BlobSequenceReader::is_present() { unsigned char detect; _datFile->ReadBuf(&detect, sizeof(detect)); if (detect == tensorIsNotPresent) { @@ -190,7 +190,7 @@ DenseTensorAttribute::getTensor(DocId docId, MutableDenseTensorView &tensor) con bool DenseTensorAttribute::onLoad() { - TensorReader tensorReader(*this); + BlobSequenceReader tensorReader(*this); if (!tensorReader.hasData()) { return false; } diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp new file mode 100644 index 00000000000..f53d42442ba --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp @@ -0,0 +1,52 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "direct_tensor_attribute.h" + +#include <vespa/eval/tensor/tensor.h> +#include <vespa/fastlib/io/bufferedfile.h> +#include <vespa/searchlib/attribute/readerbase.h> +#include <vespa/searchlib/util/fileutil.h> +#include <vespa/vespalib/util/array.h> + +#include "blob_sequence_reader.h" +#include "tensor_deserialize.h" + +using vespalib::tensor::Tensor; + +namespace search::tensor { + +constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0; + +bool +DirectTensorAttribute::onLoad() +{ + BlobSequenceReader tensorReader(*this); + if (!tensorReader.hasData()) { + return false; + } + setCreateSerialNum(tensorReader.getCreateSerialNum()); + assert(tensorReader.getVersion() == TENSOR_ATTRIBUTE_VERSION); + uint32_t numDocs = tensorReader.getDocIdLimit(); + vespalib::Array<char> buffer(1024); + for (uint32_t lid = 0; lid < numDocs; ++lid) { + uint32_t tensorSize = tensorReader.getNextSize(); + if (tensorSize != 0) { + if (tensorSize > buffer.size()) { + buffer.resize(tensorSize + 1024); + } + tensorReader.readBlob(&buffer[0], tensorSize); + setTensor(lid, deserialize_tensor(&buffer[0], tensorSize)); + } + } + setNumDocs(numDocs); + setCommittedDocIdLimit(numDocs); + return true; +} + +void +DirectTensorAttribute::setTensor(DocId , std::unique_ptr<Tensor> ) +{ + // XXX missing +} + +} // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h new file mode 100644 index 00000000000..ae3cb222dba --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h @@ -0,0 +1,25 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "tensor_attribute.h" + +namespace search::tensor { + +class DirectTensorAttribute : public TensorAttribute +{ + // XXX must have some sort of TensorStore here +public: + DirectTensorAttribute(vespalib::stringref baseFileName, const Config &cfg); + virtual ~DirectTensorAttribute(); + virtual void setTensor(DocId docId, const Tensor &tensor) override; + virtual std::unique_ptr<Tensor> getTensor(DocId docId) const override; + virtual void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override; + virtual bool onLoad() override; + virtual std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override; + virtual void compactWorst() override; + + void setTensor(DocId docId, std::unique_ptr<Tensor> tensor); +}; + +} // namespace search::tensor diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp index aac199ae818..6864fb52120 100644 --- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp @@ -3,6 +3,7 @@ #include "generic_tensor_attribute.h" #include "generic_tensor_attribute_saver.h" #include "tensor_attribute.hpp" +#include "blob_sequence_reader.h" #include <vespa/eval/tensor/tensor.h> #include <vespa/fastlib/io/bufferedfile.h> #include <vespa/searchlib/attribute/readerbase.h> @@ -18,19 +19,6 @@ namespace { constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0; -class TensorReader : public ReaderBase -{ -private: - FileReader<uint32_t> _tensorSizeReader; -public: - TensorReader(AttributeVector &attr) - : ReaderBase(attr), - _tensorSizeReader(*_datFile) - { } - uint32_t getNextTensorSize() { return _tensorSizeReader.readHostOrder(); } - void readTensor(void *buf, size_t len) { _datFile->ReadBuf(buf, len); } -}; - } GenericTensorAttribute::GenericTensorAttribute(stringref name, const Config &cfg) @@ -76,7 +64,7 @@ GenericTensorAttribute::getTensor(DocId, vespalib::tensor::MutableDenseTensorVie bool GenericTensorAttribute::onLoad() { - TensorReader tensorReader(*this); + BlobSequenceReader tensorReader(*this); if (!tensorReader.hasData()) { return false; } @@ -86,10 +74,10 @@ GenericTensorAttribute::onLoad() _refVector.reset(); _refVector.unsafe_reserve(numDocs); for (uint32_t lid = 0; lid < numDocs; ++lid) { - uint32_t tensorSize = tensorReader.getNextTensorSize(); + uint32_t tensorSize = tensorReader.getNextSize(); auto raw = _genericTensorStore.allocRawBuffer(tensorSize); if (tensorSize != 0) { - tensorReader.readTensor(raw.data, tensorSize); + tensorReader.readBlob(raw.data, tensorSize); } _refVector.push_back(raw.ref); } diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp index f19bef3ff21..8c695c32719 100644 --- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp @@ -1,15 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "generic_tensor_store.h" +#include "tensor_deserialize.h" #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/serialization/typed_binary_format.h> -#include <vespa/document/util/serializableexceptions.h> #include <vespa/vespalib/datastore/datastore.hpp> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/macro.h> -using document::DeserializeException; using vespalib::datastore::Handle; using vespalib::tensor::Tensor; using vespalib::tensor::TypedBinaryFormat; @@ -95,14 +94,7 @@ GenericTensorStore::getTensor(EntryRef ref) const if (raw.second == 0u) { return std::unique_ptr<Tensor>(); } - vespalib::nbostream wrapStream(raw.first, raw.second); - auto tensor = TypedBinaryFormat::deserialize(wrapStream); - if (wrapStream.size() != 0) { - throw DeserializeException("Leftover bytes deserializing " - "tensor attribute value.", - VESPA_STRLOC); - } - return tensor; + return deserialize_tensor(raw.first, raw.second); } TensorStore::EntryRef diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp new file mode 100644 index 00000000000..7998fba5941 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp @@ -0,0 +1,24 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/document/util/serializableexceptions.h> +#include <vespa/eval/tensor/serialization/typed_binary_format.h> +#include <vespa/eval/tensor/tensor.h> +#include <vespa/vespalib/objects/nbostream.h> + +using document::DeserializeException; +using vespalib::tensor::Tensor; +using vespalib::tensor::TypedBinaryFormat; + +namespace search::tensor { + +std::unique_ptr<Tensor> deserialize_tensor(const void *data, size_t size) +{ + vespalib::nbostream wrapStream(data, size); + auto tensor = TypedBinaryFormat::deserialize(wrapStream); + if (wrapStream.size() != 0) { + throw DeserializeException("Leftover bytes deserializing tensor attribute value.", VESPA_STRLOC); + } + return tensor; +} + +} // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h new file mode 100644 index 00000000000..f1dfa1ca173 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h @@ -0,0 +1,10 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/tensor/tensor.h> + +namespace search::tensor { + +extern std::unique_ptr<vespalib::tensor::Tensor> +deserialize_tensor(const void *data, size_t size); + +} // namespace |