aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h
blob: 3aa52fe622aebe5a4d77b38c14d8d34d891d5081 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include "default_nearest_neighbor_index_factory.h"
#include "dense_tensor_store.h"
#include "doc_vector_access.h"
#include "tensor_attribute.h"
#include "typed_cells_comparator.h"
#include <memory>

namespace search::tensor {

class NearestNeighborIndex;

/**
 * In-memory attribute vector holding one dense tensor per document.
 *
 * Tensors are kept in a DenseTensorStore. A nearest neighbor index, created
 * through the factory given to the constructor, is owned alongside the store;
 * nearest_neighbor_index() exposes it and may return null (presumably when no
 * index is configured — confirm in the implementation file).
 *
 * The class also implements DocVectorAccess, exposing per-document vectors
 * through get_vector().
 */
class DenseTensorAttribute : public TensorAttribute, public DocVectorAccess {
private:
    // ---- State -------------------------------------------------------------
    // NOTE: member declaration order is also construction order; keep it.
    DenseTensorStore _denseTensorStore;            // storage for the dense tensors
    std::unique_ptr<NearestNeighborIndex> _index;  // owned index; may be null
    TypedCellsComparator _comp;                    // NOTE(review): presumably used to detect unchanged tensors — confirm

    // ---- Loading helpers (only forward-declared here) ---------------------
    class ThreadedLoader;
    class ForegroundLoader;

    // ---- Internal update helpers -------------------------------------------
    bool tensor_is_unchanged(DocId docid, const vespalib::eval::Value& new_tensor) const;
    void internal_set_tensor(DocId docid, const vespalib::eval::Value& tensor);
    void consider_remove_from_index(DocId docid);

    // ---- Memory accounting (private overrides) -----------------------------
    vespalib::MemoryUsage update_stat() override;
    vespalib::MemoryUsage memory_usage() const override;
    void populate_address_space_usage(AddressSpaceUsage& usage) const override;

public:
    // ---- Construction ------------------------------------------------------
    DenseTensorAttribute(vespalib::stringref baseFileName, const Config& cfg,
                         const NearestNeighborIndexFactory& index_factory = DefaultNearestNeighborIndexFactory());
    ~DenseTensorAttribute() override;

    // ---- AttributeVector / ITensorAttribute: document updates ---------------
    uint32_t clearDoc(DocId docId) override;
    void setTensor(DocId docId, const vespalib::eval::Value& tensor) override;
    std::unique_ptr<PrepareResult> prepare_set_tensor(DocId docid, const vespalib::eval::Value& tensor) const override;
    void complete_set_tensor(DocId docid, const vespalib::eval::Value& tensor, std::unique_ptr<PrepareResult> prepare_result) override;

    // ---- AttributeVector / ITensorAttribute: tensor reads ------------------
    std::unique_ptr<vespalib::eval::Value> getTensor(DocId docId) const override;
    vespalib::eval::TypedCells extract_cells_ref(DocId docId) const override;
    bool supports_extract_cells_ref() const override { return true; }
    const NearestNeighborIndex* nearest_neighbor_index() const override { return _index.get(); }

    // ---- AttributeVector: persistence --------------------------------------
    bool onLoad(vespalib::Executor *executor) override;
    std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override;
    uint32_t getVersion() const override;

    // ---- AttributeVector: commit, generations and lid space ----------------
    void onCommit() override;
    void before_inc_generation(generation_t current_gen) override;
    void reclaim_memory(generation_t oldest_used_gen) override;
    void onShrinkLidSpace() override;
    void get_state(const vespalib::slime::Inserter& inserter) const override;

    // ---- DocVectorAccess ---------------------------------------------------
    vespalib::eval::TypedCells get_vector(uint32_t docid) const override;
};

}