From 03f71b36ce970a0207108702e4d1b6bf9b1fcabb Mon Sep 17 00:00:00 2001
From: Lester Solbakken
Date: Thu, 14 Jun 2018 11:17:31 +0200
Subject: Refactor out normal distributed random generator

RandomNormalExecutor delegates Gaussian sampling to search::RandomNormal.
Output 0 stays the per-query draw; output 1 stays the per-document draw,
taken from a generator that is reseeded with matchSeed + docId.

---
Note: the operands of the bare "#include" lines in the hunks below are
missing from this copy of the patch and have been left as found.

 .../searchlib/features/random_normal_feature.cpp   | 46 ++----------------
 .../searchlib/features/random_normal_feature.h     | 11 ++---
 searchlib/src/vespa/searchlib/util/random_normal.h | 56 ++++++++++++++++++++++
 3 files changed, 65 insertions(+), 48 deletions(-)
 create mode 100644 searchlib/src/vespa/searchlib/util/random_normal.h

diff --git a/searchlib/src/vespa/searchlib/features/random_normal_feature.cpp b/searchlib/src/vespa/searchlib/features/random_normal_feature.cpp
index c15f072aaa6..ddf9f9f016a 100644
--- a/searchlib/src/vespa/searchlib/features/random_normal_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/random_normal_feature.cpp
@@ -4,7 +4,6 @@
 #include "utils.h"
 #include
 #include
-#include
 #include

 LOG_SETUP(".features.randomnormalfeature");
@@ -18,53 +17,18 @@ RandomNormalExecutor::RandomNormalExecutor(uint64_t seed, uint64_t matchSeed, do
     _matchRnd(),
     _matchSeed(matchSeed),
     _mean(mean),
-    _stddev(stddev),
-    _hasSpare(false),
-    _spare(0.0)
+    _stddev(stddev)
 {
     LOG(debug, "RandomNormalExecutor: seed=%zu, matchSeed=%zu, mean=%f, stddev=%f", seed, matchSeed, mean, stddev);
-    _rnd.srand48(seed);
+    _rnd.seed(seed);
 }

-feature_t generateRandom(Rand48& generator) {
-    return (generator.lrand48() / (feature_t)0x80000000u) * 2.0 - 1.0;
-}
-
-/**
- * Draws a random number from the Gaussian distribution
- * using the Marsaglia polar method.
- */
 void
 RandomNormalExecutor::execute(uint32_t docId)
 {
-    feature_t result = _spare;
-    if (_hasSpare) {
-        _hasSpare = false;
-    } else {
-        _hasSpare = true;
-
-        feature_t u, v, s;
-        do {
-            u = generateRandom(_rnd);
-            v = generateRandom(_rnd);
-            s = u * u + v * v;
-        } while ( (s >= 1.0) || (s == 0.0) );
-        s = std::sqrt(-2.0 * std::log(s) / s);
-
-        _spare = v * s; // saved for next invocation
-        result = u * s;
-    }
-    outputs().set_number(0, _mean + _stddev * result);
-
-    _matchRnd.srand48(_matchSeed + docId);
-    feature_t u, v, s;
-    do {
-        u = generateRandom(_matchRnd);
-        v = generateRandom(_matchRnd);
-        s = u * u + v * v;
-    } while ( (s >= 1.0) || (s == 0.0) );
-    s = std::sqrt(-2.0 * std::log(s) / s);
-    outputs().set_number(1, _mean + _stddev * u * s);
+    outputs().set_number(0, _mean + _stddev * _rnd.next());            // output 0: draw from the generator seeded in the constructor
+    _matchRnd.seed(_matchSeed + docId);                                // reseed for this document
+    outputs().set_number(1, _mean + _stddev * _matchRnd.next(false));  // output 1: fresh draw; false = never hand out a value cached for an earlier document
 }

 RandomNormalBlueprint::RandomNormalBlueprint() :
diff --git a/searchlib/src/vespa/searchlib/features/random_normal_feature.h b/searchlib/src/vespa/searchlib/features/random_normal_feature.h
index f2bc82704bb..9ce8f899446 100644
--- a/searchlib/src/vespa/searchlib/features/random_normal_feature.h
+++ b/searchlib/src/vespa/searchlib/features/random_normal_feature.h
@@ -4,7 +4,7 @@

 #include
 #include
-#include
+#include

 namespace search {
 namespace features {
@@ -17,16 +17,13 @@ namespace features {
  **/
 class RandomNormalExecutor : public fef::FeatureExecutor {
 private:
-    Rand48      _rnd;       // seeded once per query
-    Rand48      _matchRnd;  // seeded once per match
-    uint64_t    _matchSeed;
+    RandomNormal _rnd;       // seeded once per query
+    RandomNormal _matchRnd;  // seeded once per match
+    uint64_t     _matchSeed;

     double      _mean;
     double      _stddev;

-    bool        _hasSpare;
-    double      _spare;
-
 public:
     RandomNormalExecutor(uint64_t seed, uint64_t matchSeed, double mean, double stddev);
     void execute(uint32_t docId) override;
diff --git a/searchlib/src/vespa/searchlib/util/random_normal.h
b/searchlib/src/vespa/searchlib/util/random_normal.h
new file mode 100644
index 00000000000..0c2da580db6
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/util/random_normal.h
@@ -0,0 +1,75 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+// NOTE(review): the operands of these two includes were missing from this copy
+// of the patch. <cmath> is required for std::sqrt/std::log; the second is
+// presumably the header declaring Rand48 -- confirm its path.
+#include <cmath>
+#include <vespa/searchlib/util/rand48.h>
+
+namespace search {
+
+/**
+ * Source of normally distributed random numbers, driven by a Rand48 generator.
+ * Each round of the polar method yields two values; the second one is cached
+ * and handed out by the following call to next().
+ */
+class RandomNormal
+{
+private:
+    Rand48    _rnd;
+    bool      _hasSpare;  // true while _spare holds a value not yet returned
+    feature_t _spare;
+
+    /** Uniform draw in [-1, 1), assuming lrand48() yields [0, 2^31) -- TODO confirm against Rand48. */
+    feature_t nextUniform() {
+        return (_rnd.lrand48() / (feature_t)0x80000000u) * 2.0 - 1.0;
+    }
+
+public:
+    RandomNormal() : _rnd(), _hasSpare(false), _spare(0.0) {}
+
+    /**
+     * Reseeds the underlying generator. A cached spare value is dropped as
+     * well, since it was derived from the previous seed.
+     */
+    void seed(long seed) {
+        _rnd.srand48(seed);
+        _hasSpare = false;
+    }
+
+    /**
+     * Draws a random number from the Gaussian distribution
+     * using the Marsaglia polar method.
+     *
+     * @param useSpare when true, the value cached by the previous draw is
+     *                 returned if one is available; when false, a fresh
+     *                 pair is always drawn (and its second value cached).
+     * @return the drawn value; callers apply their own mean and stddev
+     */
+    feature_t next(bool useSpare = true) {
+        feature_t result = _spare;
+        if (_hasSpare && useSpare) {
+            _hasSpare = false;
+        } else {
+            _hasSpare = true;
+
+            feature_t u, v, s;
+            do {
+                u = nextUniform();
+                v = nextUniform();
+                s = u * u + v * v;
+            } while ( (s >= 1.0) || (s == 0.0) );  // retry unless 0 < s < 1, so log(s) / s below is well defined
+            s = std::sqrt(-2.0 * std::log(s) / s);
+
+            _spare = v * s; // saved for next invocation
+            result = u * s;
+        }
+        return result;
+    }
+
+};
+
+} // search
+
-- 
cgit v1.2.3