aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/util/random_normal.h
blob: 18f51284b6e0229079e706a4f5895d9e1ea223af (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include <vespa/vespalib/util/rand48.h>
#include <cmath>

namespace search {

/**
 * Draws a random number from the Gaussian distribution
 * using the Marsaglia polar method.
 */
class RandomNormal
{
private:
    vespalib::Rand48    _rnd;
    double    _mean;
    double    _stddev;

    bool      _useSpare;
    bool      _hasSpare;
    feature_t _spare;

    feature_t nextUniform() {
        return (_rnd.lrand48() / (feature_t)0x80000000u) * 2.0 - 1.0;
    }

public:
    RandomNormal(double mean, double stddev, bool useSpare = true) :
            _rnd(),
            _mean(mean),
            _stddev(stddev),
            _useSpare(useSpare),
            _hasSpare(false),
            _spare(0.0)
    {}

    void seed(long seed) {
        _rnd.srand48(seed);
    }

    feature_t next() {
        feature_t result = _spare;
        if (_useSpare && _hasSpare) {
            _hasSpare = false;
        } else {
            _hasSpare = true;

            feature_t u, v, s;
            do {
                u = nextUniform();
                v = nextUniform();
                s = u * u + v * v;
            } while ( (s >= 1.0) || (s == 0.0) );
            s = std::sqrt(-2.0 * std::log(s) / s);

            _spare = v * s; // saved for next invocation
            result = u * s;
        }
        return _mean + _stddev * result;
    }

};

} // search