aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2020-08-28 16:10:16 +0200
committerGitHub <noreply@github.com>2020-08-28 16:10:16 +0200
commit5f13cd33ac9015e0bd17f9c2c8080e8183480b92 (patch)
tree57e7e21556a952bee1d605592a6c343fc60ab8ee
parent9a8332e0fc7f38e914b51dcda5fed8aad73a044f (diff)
parentb0b3699572067be679a393d41fe495c3fad3f85f (diff)
Merge pull request #14185 from vespa-engine/arnej/load-direct-tensor-attribute
Arnej/load direct tensor attribute
-rw-r--r--searchlib/src/vespa/searchlib/tensor/CMakeLists.txt2
-rw-r--r--searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h26
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp14
-rw-r--r--searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp52
-rw-r--r--searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h25
-rw-r--r--searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp20
-rw-r--r--searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp12
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp24
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h10
9 files changed, 152 insertions, 33 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
index 35615b255c0..8e5405ee179 100644
--- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
@@ -5,6 +5,7 @@ vespa_add_library(searchlib_tensor OBJECT
dense_tensor_attribute.cpp
dense_tensor_attribute_saver.cpp
dense_tensor_store.cpp
+ direct_tensor_attribute.cpp
distance_function_factory.cpp
distance_functions.cpp
generic_tensor_attribute.cpp
@@ -20,6 +21,7 @@ vespa_add_library(searchlib_tensor OBJECT
nearest_neighbor_index.cpp
nearest_neighbor_index_saver.cpp
tensor_attribute.cpp
+ tensor_deserialize.cpp
tensor_store.cpp
DEPENDS
)
diff --git a/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h b/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h
new file mode 100644
index 00000000000..7c34b60e93d
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h
@@ -0,0 +1,26 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/fastlib/io/bufferedfile.h>
+#include <vespa/searchlib/attribute/readerbase.h>
+#include <vespa/searchlib/util/fileutil.h>
+
+namespace search::tensor {
+
+/**
+ * Utility for reading an attribute data file where
+ * the format is a sequence of blobs (size, byte[size]).
+ **/
+class BlobSequenceReader : public ReaderBase
+{
+private:
+ FileReader<uint32_t> _sizeReader;
+public:
+ BlobSequenceReader(AttributeVector &attr)
+ : ReaderBase(attr),
+ _sizeReader(*_datFile)
+ { }
+ uint32_t getNextSize() { return _sizeReader.readHostOrder(); }
+ void readBlob(void *buf, size_t len) { _datFile->ReadBuf(buf, len); }
+};
+
+} // namespace
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
index 76533839de7..37a042d4e7f 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
@@ -30,26 +30,26 @@ namespace {
constexpr uint32_t DENSE_TENSOR_ATTRIBUTE_VERSION = 1;
const vespalib::string tensorTypeTag("tensortype");
-class TensorReader : public ReaderBase
+class BlobSequenceReader : public ReaderBase
{
private:
static constexpr uint8_t tensorIsNotPresent = 0;
static constexpr uint8_t tensorIsPresent = 1;
public:
- TensorReader(AttributeVector &attr);
- ~TensorReader();
+ BlobSequenceReader(AttributeVector &attr);
+ ~BlobSequenceReader();
bool is_present();
void readTensor(void *buf, size_t len) { _datFile->ReadBuf(buf, len); }
};
-TensorReader::TensorReader(AttributeVector &attr)
+BlobSequenceReader::BlobSequenceReader(AttributeVector &attr)
: ReaderBase(attr)
{
}
-TensorReader::~TensorReader() = default;
+BlobSequenceReader::~BlobSequenceReader() = default;
bool
-TensorReader::is_present() {
+BlobSequenceReader::is_present() {
unsigned char detect;
_datFile->ReadBuf(&detect, sizeof(detect));
if (detect == tensorIsNotPresent) {
@@ -190,7 +190,7 @@ DenseTensorAttribute::getTensor(DocId docId, MutableDenseTensorView &tensor) con
bool
DenseTensorAttribute::onLoad()
{
- TensorReader tensorReader(*this);
+ BlobSequenceReader tensorReader(*this);
if (!tensorReader.hasData()) {
return false;
}
diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp
new file mode 100644
index 00000000000..f53d42442ba
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp
@@ -0,0 +1,52 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "direct_tensor_attribute.h"
+
+#include <vespa/eval/tensor/tensor.h>
+#include <vespa/fastlib/io/bufferedfile.h>
+#include <vespa/searchlib/attribute/readerbase.h>
+#include <vespa/searchlib/util/fileutil.h>
+#include <vespa/vespalib/util/array.h>
+
+#include "blob_sequence_reader.h"
+#include "tensor_deserialize.h"
+
+using vespalib::tensor::Tensor;
+
+namespace search::tensor {
+
+constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0;
+
+bool
+DirectTensorAttribute::onLoad()
+{
+ BlobSequenceReader tensorReader(*this);
+ if (!tensorReader.hasData()) {
+ return false;
+ }
+ setCreateSerialNum(tensorReader.getCreateSerialNum());
+ assert(tensorReader.getVersion() == TENSOR_ATTRIBUTE_VERSION);
+ uint32_t numDocs = tensorReader.getDocIdLimit();
+ vespalib::Array<char> buffer(1024);
+ for (uint32_t lid = 0; lid < numDocs; ++lid) {
+ uint32_t tensorSize = tensorReader.getNextSize();
+ if (tensorSize != 0) {
+ if (tensorSize > buffer.size()) {
+ buffer.resize(tensorSize + 1024);
+ }
+ tensorReader.readBlob(&buffer[0], tensorSize);
+ setTensor(lid, deserialize_tensor(&buffer[0], tensorSize));
+ }
+ }
+ setNumDocs(numDocs);
+ setCommittedDocIdLimit(numDocs);
+ return true;
+}
+
+void
+DirectTensorAttribute::setTensor(DocId , std::unique_ptr<Tensor> )
+{
+ // XXX missing
+}
+
+} // namespace
diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h
new file mode 100644
index 00000000000..ae3cb222dba
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h
@@ -0,0 +1,25 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "tensor_attribute.h"
+
+namespace search::tensor {
+
+class DirectTensorAttribute : public TensorAttribute
+{
+ // XXX must have some sort of TensorStore here
+public:
+ DirectTensorAttribute(vespalib::stringref baseFileName, const Config &cfg);
+ virtual ~DirectTensorAttribute();
+ virtual void setTensor(DocId docId, const Tensor &tensor) override;
+ virtual std::unique_ptr<Tensor> getTensor(DocId docId) const override;
+ virtual void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override;
+ virtual bool onLoad() override;
+ virtual std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override;
+ virtual void compactWorst() override;
+
+ void setTensor(DocId docId, std::unique_ptr<Tensor> tensor);
+};
+
+} // namespace search::tensor
diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
index aac199ae818..6864fb52120 100644
--- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
@@ -3,6 +3,7 @@
#include "generic_tensor_attribute.h"
#include "generic_tensor_attribute_saver.h"
#include "tensor_attribute.hpp"
+#include "blob_sequence_reader.h"
#include <vespa/eval/tensor/tensor.h>
#include <vespa/fastlib/io/bufferedfile.h>
#include <vespa/searchlib/attribute/readerbase.h>
@@ -18,19 +19,6 @@ namespace {
constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0;
-class TensorReader : public ReaderBase
-{
-private:
- FileReader<uint32_t> _tensorSizeReader;
-public:
- TensorReader(AttributeVector &attr)
- : ReaderBase(attr),
- _tensorSizeReader(*_datFile)
- { }
- uint32_t getNextTensorSize() { return _tensorSizeReader.readHostOrder(); }
- void readTensor(void *buf, size_t len) { _datFile->ReadBuf(buf, len); }
-};
-
}
GenericTensorAttribute::GenericTensorAttribute(stringref name, const Config &cfg)
@@ -76,7 +64,7 @@ GenericTensorAttribute::getTensor(DocId, vespalib::tensor::MutableDenseTensorVie
bool
GenericTensorAttribute::onLoad()
{
- TensorReader tensorReader(*this);
+ BlobSequenceReader tensorReader(*this);
if (!tensorReader.hasData()) {
return false;
}
@@ -86,10 +74,10 @@ GenericTensorAttribute::onLoad()
_refVector.reset();
_refVector.unsafe_reserve(numDocs);
for (uint32_t lid = 0; lid < numDocs; ++lid) {
- uint32_t tensorSize = tensorReader.getNextTensorSize();
+ uint32_t tensorSize = tensorReader.getNextSize();
auto raw = _genericTensorStore.allocRawBuffer(tensorSize);
if (tensorSize != 0) {
- tensorReader.readTensor(raw.data, tensorSize);
+ tensorReader.readBlob(raw.data, tensorSize);
}
_refVector.push_back(raw.ref);
}
diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp
index f19bef3ff21..8c695c32719 100644
--- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp
@@ -1,15 +1,14 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "generic_tensor_store.h"
+#include "tensor_deserialize.h"
#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/serialization/typed_binary_format.h>
-#include <vespa/document/util/serializableexceptions.h>
#include <vespa/vespalib/datastore/datastore.hpp>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/macro.h>
-using document::DeserializeException;
using vespalib::datastore::Handle;
using vespalib::tensor::Tensor;
using vespalib::tensor::TypedBinaryFormat;
@@ -95,14 +94,7 @@ GenericTensorStore::getTensor(EntryRef ref) const
if (raw.second == 0u) {
return std::unique_ptr<Tensor>();
}
- vespalib::nbostream wrapStream(raw.first, raw.second);
- auto tensor = TypedBinaryFormat::deserialize(wrapStream);
- if (wrapStream.size() != 0) {
- throw DeserializeException("Leftover bytes deserializing "
- "tensor attribute value.",
- VESPA_STRLOC);
- }
- return tensor;
+ return deserialize_tensor(raw.first, raw.second);
}
TensorStore::EntryRef
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp
new file mode 100644
index 00000000000..7998fba5941
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp
@@ -0,0 +1,24 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/document/util/serializableexceptions.h>
+#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/eval/tensor/tensor.h>
+#include <vespa/vespalib/objects/nbostream.h>
+
+using document::DeserializeException;
+using vespalib::tensor::Tensor;
+using vespalib::tensor::TypedBinaryFormat;
+
+namespace search::tensor {
+
+std::unique_ptr<Tensor> deserialize_tensor(const void *data, size_t size)
+{
+ vespalib::nbostream wrapStream(data, size);
+ auto tensor = TypedBinaryFormat::deserialize(wrapStream);
+ if (wrapStream.size() != 0) {
+ throw DeserializeException("Leftover bytes deserializing tensor attribute value.", VESPA_STRLOC);
+ }
+ return tensor;
+}
+
+} // namespace
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h
new file mode 100644
index 00000000000..f1dfa1ca173
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h
@@ -0,0 +1,10 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/tensor/tensor.h>
+
+namespace search::tensor {
+
+extern std::unique_ptr<vespalib::tensor::Tensor>
+deserialize_tensor(const void *data, size_t size);
+
+} // namespace