aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
blob: eab75537071574ec4c9ad58a5a07c2b2e38ccc00 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once

#include "distance_function.h"
#include "distance_function_factory.h"
#include "i_tensor_attribute.h"
#include "vector_bundle.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <memory>
#include <optional>

namespace vespalib::eval { struct Value; }

namespace search::attribute { class IAttributeVector; }

namespace search::tensor {

/**
 * Class used to calculate the distance between two n-dimensional vectors,
 * where one is stored in a TensorAttribute and the other comes from the query.
 *
 * The distance function to use is defined in the TensorAttribute.
 */
class DistanceCalculator {
private:
    // Non-owning: the attribute must outlive this calculator.
    const tensor::ITensorAttribute& _attr_tensor;
    // Non-owning pointer to the query tensor (see query_tensor()); the tensor must outlive this calculator.
    const vespalib::eval::Value* _query_tensor;
    // Owned distance function. calc() takes only the document-side cells, so it is
    // presumably bound to the query tensor at construction — TODO confirm in the .cpp.
    std::unique_ptr<BoundDistanceFunction> _dist_fun;

public:
    /**
     * Defined out of line. The members only hold a reference/pointer to the inputs
     * (no ownership), so attr_tensor and query_tensor_in must outlive this object.
     */
    DistanceCalculator(const tensor::ITensorAttribute& attr_tensor,
                       const vespalib::eval::Value& query_tensor_in);

    ~DistanceCalculator();

    const tensor::ITensorAttribute& attribute_tensor() const { return _attr_tensor; }
    const vespalib::eval::Value& query_tensor() const {
        assert(_query_tensor != nullptr);
        return *_query_tensor;
    }
    const BoundDistanceFunction& function() const { return *_dist_fun; }

    /**
     * Returns the highest raw score between the query and any subspace of the
     * tensor stored for docid.
     *
     * Each subspace distance is converted with to_rawscore() and the maximum is kept.
     * If the document has no subspaces, min_rawscore() is returned.
     */
    double calc_raw_score(uint32_t docid) const {
        auto vectors = _attr_tensor.get_vectors(docid);
        double result = _dist_fun->min_rawscore();
        for (uint32_t i = 0; i < vectors.subspaces(); ++i) {
            const double distance = _dist_fun->calc(vectors.cells(i));
            const double score = _dist_fun->to_rawscore(distance);
            result = std::max(result, score);
        }
        return result;
    }

    /**
     * Returns the smallest distance between the query and any subspace of the
     * tensor stored for docid, using BoundDistanceFunction::calc_with_limit().
     *
     * NOTE(review): 'limit' is forwarded unchanged to the distance function; presumably
     * it allows early termination once the distance exceeds it — confirm in distance_function.h.
     * If the document has no subspaces, std::numeric_limits<double>::max() is returned.
     */
    double calc_with_limit(uint32_t docid, double limit) const {
        auto vectors = _attr_tensor.get_vectors(docid);
        double result = std::numeric_limits<double>::max();
        for (uint32_t i = 0; i < vectors.subspaces(); ++i) {
            const double distance = _dist_fun->calc_with_limit(vectors.cells(i), limit);
            result = std::min(result, distance);
        }
        return result;
    }

    /**
     * Scans every subspace in 'vectors' and updates a running minimum in place.
     *
     * If closest_subspace is empty, the first subspace scanned is always taken.
     * Otherwise a subspace replaces the current best only if its distance is strictly
     * smaller than best_distance, so ties keep the earlier subspace. best_distance is
     * only read when closest_subspace already has a value. The stored index is
     * relative to 'vectors'.
     *
     * Declared const: it only reads members, so it is usable through const references.
     */
    void calc_closest_subspace(VectorBundle vectors, std::optional<uint32_t>& closest_subspace, double& best_distance) const {
        for (uint32_t i = 0; i < vectors.subspaces(); ++i) {
            const double distance = _dist_fun->calc(vectors.cells(i));
            if (!closest_subspace.has_value() || distance < best_distance) {
                best_distance = distance;
                closest_subspace = i;
            }
        }
    }

    /**
     * Returns the index of the subspace in 'vectors' closest to the query.
     * Returns std::nullopt if 'vectors' has no subspaces.
     */
    std::optional<uint32_t> calc_closest_subspace(VectorBundle vectors) const {
        double best_distance = 0.0; // ignored until closest_subspace is set
        std::optional<uint32_t> closest_subspace;
        calc_closest_subspace(vectors, closest_subspace, best_distance);
        return closest_subspace;
    }

    /**
     * Create a calculator for the given attribute tensor and query tensor, if possible.
     *
     * Throws vespalib::IllegalArgumentException if the inputs are not supported or incompatible.
     */
    static std::unique_ptr<DistanceCalculator> make_with_validation(const search::attribute::IAttributeVector& attr,
                                                                    const vespalib::eval::Value& query_tensor_in);

};

}