summaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp
blob: 07377630299240747794c6f4d9835e7a874ac46d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "generic_tensor_attribute.h"
#include "generic_tensor_attribute_saver.h"
#include "tensor_attribute.hpp"
#include <vespa/eval/tensor/tensor.h>
#include <vespa/searchlib/common/rcuvector.hpp>
#include <vespa/fastlib/io/bufferedfile.h>
#include <vespa/searchlib/attribute/readerbase.h>
#include <vespa/searchlib/util/fileutil.h>

using vespalib::eval::ValueType;
using vespalib::tensor::Tensor;
using vespalib::tensor::TensorMapper;

namespace search {

namespace tensor {

namespace {

// Data-file format version that onLoad() accepts.
constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0;

/**
 * Reader for a generic tensor attribute data file.
 *
 * Per document, the file is consumed as a uint32_t size, obtained through a
 * FileReader<uint32_t> over the data file (readHostOrder()), followed by
 * `size` raw bytes read with readTensor().
 */
class TensorReader : public ReaderBase
{
private:
    // Size reader layered on top of ReaderBase's data file.
    FileReader<uint32_t> _tensorSizeReader;
public:
    // explicit: an AttributeVector must not silently convert into a reader.
    explicit TensorReader(AttributeVector &attr)
        : ReaderBase(attr),
          _tensorSizeReader(*_datFile)
    { }
    // Returns the byte size of the next document's tensor data (0 = none stored).
    uint32_t getNextTensorSize() { return _tensorSizeReader.readHostOrder(); }
    // Reads `len` bytes of tensor data from the data file into `buf`.
    void readTensor(void *buf, size_t len) { _datFile->ReadBuf(buf, len); }
};

}

/**
 * Constructs the attribute, handing the base class a reference to
 * _genericTensorStore as its tensor store.
 *
 * NOTE(review): base classes are initialized before this class's own members.
 * If _genericTensorStore is declared in this class (presumably, in the header),
 * it is not yet constructed while TensorAttribute's constructor runs.
 * The base must therefore only bind the reference there, never use it.
 */
GenericTensorAttribute::GenericTensorAttribute(const vespalib::stringref &baseFileName,
                                               const Config &cfg)
    : TensorAttribute(baseFileName, cfg, _genericTensorStore)
{ }


// Clears pending hold lists in the destructor body, i.e. before any members
// are destroyed: first the attribute's generation holder, then the tensor
// store's hold lists.
GenericTensorAttribute::~GenericTensorAttribute()
{
    this->getGenerationHolder().clearHoldLists();
    this->_tensorStore.clearHoldLists();
}

/**
 * Stores `tensor` for `docId`. If a tensor mapper is configured, the tensor is
 * first mapped and the mapped result is stored instead.
 */
void
GenericTensorAttribute::setTensor(DocId docId, const Tensor &tensor)
{
    // The mapped tensor is a temporary owned by this full-expression.
    // It stays alive until the store call has returned.
    RefType ref = _tensorMapper
                  ? _genericTensorStore.setTensor(*_tensorMapper->map(tensor))
                  : _genericTensorStore.setTensor(tensor);
    setTensorRef(docId, ref);
}


/**
 * Returns the tensor stored for `docId`, as produced by the tensor store.
 * Returns an empty pointer when `docId` is not below the committed doc id
 * limit, or when its stored reference is not valid.
 */
std::unique_ptr<Tensor>
GenericTensorAttribute::getTensor(DocId docId) const
{
    if (docId >= getCommittedDocIdLimit()) {
        return std::unique_ptr<Tensor>();
    }
    RefType ref = _refVector[docId];
    if (ref.valid()) {
        return _genericTensorStore.getTensor(ref);
    }
    return std::unique_ptr<Tensor>();
}

/**
 * Loads the attribute from its data file.
 *
 * For each document id in [0, docIdLimit), the data file supplies a uint32_t
 * tensor size followed by that many bytes. Each document gets a buffer from
 * allocRawBuffer(size). The bytes, if any, are read into that buffer, and the
 * buffer's ref is appended to _refVector. A size of 0 still goes through
 * allocRawBuffer(0), presumably yielding an invalid ref meaning
 * "no tensor" -- TODO confirm.
 *
 * @return false if the file has no data or an unsupported format version.
 *         No attribute state is modified by this function in either case.
 *         Returns true once all documents have been loaded.
 */
bool
GenericTensorAttribute::onLoad()
{
    TensorReader tensorReader(*this);
    if (!tensorReader.hasData()) {
        return false;
    }
    // Checked at runtime rather than with assert(). An assert is compiled out
    // under NDEBUG, and an unknown layout would then be parsed as version 0.
    // Done before setCreateSerialNum() so a rejected file leaves state untouched.
    if (tensorReader.getVersion() != TENSOR_ATTRIBUTE_VERSION) {
        return false;
    }
    setCreateSerialNum(tensorReader.getCreateSerialNum());
    const uint32_t numDocs(tensorReader.getDocIdLimit());
    _refVector.reset();
    // Reserve room for every document before the push_back loop below.
    _refVector.unsafe_reserve(numDocs);
    for (uint32_t lid = 0; lid < numDocs; ++lid) {
        const uint32_t tensorSize = tensorReader.getNextTensorSize();
        auto raw = _genericTensorStore.allocRawBuffer(tensorSize);
        if (tensorSize != 0) {
            tensorReader.readTensor(raw.data, tensorSize);
        }
        _refVector.push_back(raw.ref);
    }
    setNumDocs(numDocs);
    setCommittedDocIdLimit(numDocs);
    return true;
}


std::unique_ptr<AttributeSaver>
GenericTensorAttribute::onInitSave()
{
    vespalib::GenerationHandler::Guard guard(getGenerationHandler().
                                             takeGuard());
    return std::make_unique<GenericTensorAttributeSaver>
        (std::move(guard),
         this->createAttributeHeader(),
         getRefCopy(),
         _genericTensorStore);
}

void
GenericTensorAttribute::compactWorst()
{
    doCompactWorst<GenericTensorStore::RefType>();
}


}  // namespace search::tensor

}  // namespace search