aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/attribute/numeric_range_matcher.h
blob: 7f1c3e3136725581195db80ad605948200ed2569 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include <vespa/searchcommon/common/range.h>
#include <cstddef>
#include <cstdint>
#include <limits>

namespace search { class QueryTermSimple; }

namespace search::attribute {

/*
 * Class used to determine if an attribute vector value is a match for
 * the query range.
 *
 * Holds an inclusive range [_low, _high] of type T together with a few
 * auxiliary settings. All of these are initialized by the out-of-line
 * constructor from a QueryTermSimple. The query-facing helpers are
 * protected, so the class is meant to be used as a base class.
 */
template<typename T>
class NumericRangeMatcher
{
protected:
    T _low;   // inclusive lower bound (see match())
    T _high;  // inclusive upper bound (see match())
private:
    bool _valid;
    int _limit;
    size_t _max_per_group;
public:
    NumericRangeMatcher(const QueryTermSimple& queryTerm) : NumericRangeMatcher(queryTerm, false) {}
    NumericRangeMatcher(const QueryTermSimple& queryTerm, bool avoidUndefinedInRange);
protected:
    // Range bounds converted to int64_t.
    // NOTE(review): the conversion is unchecked; for a floating point T whose
    // bounds lie outside the int64_t range the static_cast is undefined.
    Int64Range getRange() const {
        return {static_cast<int64_t>(_low), static_cast<int64_t>(_high)};
    }
    // Range bounds converted to double.
    DoubleRange getDoubleRange() const {
        return {static_cast<double>(_low), static_cast<double>(_high)};
    }
    // Validity flag computed by the constructor.
    // NOTE(review): presumably false when the query term could not be turned
    // into a numeric range — confirm against the constructor definition.
    bool isValid() const { return _valid; }
    // True when v lies within [_low, _high] (both ends inclusive).
    bool match(T v) const { return (_low <= v) && (v <= _high); }
    // Range limit recorded by the constructor (semantics defined there).
    int getRangeLimit() const { return _limit; }
    // Max-per-group value recorded by the constructor (semantics defined there).
    size_t getMaxPerGroup() const { return _max_per_group; }

    /*
     * Returns the query range converted to BaseType, with each bound capped
     * to what BaseType can hold:
     *   - lower bound: -max() when isFloat, otherwise min() + 1 (min() itself
     *     is excluded because it is the undefined value);
     *   - upper bound: max().
     *
     * isFloat must agree with BaseType. For floating point types
     * numeric_limits::min() is the smallest *positive* value, not the most
     * negative one, which is why the float branch caps against -max().
     *
     * Capping is decided in T *before* converting, so a bound that does not
     * fit in BaseType is never passed to static_cast. Converting an
     * out-of-range floating point value to an integer type is undefined
     * behaviour, even if the result would be overwritten afterwards.
     */
    template <typename BaseType>
    search::Range<BaseType>
    cappedRange(bool isFloat) const
    {
        constexpr BaseType numMin = std::numeric_limits<BaseType>::min();
        constexpr BaseType numMax = std::numeric_limits<BaseType>::max();

        BaseType low;
        if (isFloat) {
            low = (_low <= (-numMax))
                    ? static_cast<BaseType>(-numMax)
                    : static_cast<BaseType>(_low);
        } else {
            // we must avoid the undefined value
            low = (_low <= (numMin))
                    ? static_cast<BaseType>(numMin + 1)
                    : static_cast<BaseType>(_low);
        }

        const BaseType high = (_high >= (numMax))
                ? numMax
                : static_cast<BaseType>(_high);

        return search::Range<BaseType>(low, high);
    }
};

}