aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-06-14 11:17:31 +0200
committerLester Solbakken <lesters@oath.com>2018-06-14 11:17:31 +0200
commit03f71b36ce970a0207108702e4d1b6bf9b1fcabb (patch)
tree95e54551e6d371776bf1ead99f620d8f05f113fd /searchlib
parenta5cb701cad94076774b900edcc68757b6e20c93e (diff)
Refactor out normal distributed random generator
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/features/random_normal_feature.cpp46
-rw-r--r--searchlib/src/vespa/searchlib/features/random_normal_feature.h11
-rw-r--r--searchlib/src/vespa/searchlib/util/random_normal.h56
3 files changed, 65 insertions, 48 deletions
diff --git a/searchlib/src/vespa/searchlib/features/random_normal_feature.cpp b/searchlib/src/vespa/searchlib/features/random_normal_feature.cpp
index c15f072aaa6..ddf9f9f016a 100644
--- a/searchlib/src/vespa/searchlib/features/random_normal_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/random_normal_feature.cpp
@@ -4,7 +4,6 @@
#include "utils.h"
#include <vespa/searchlib/fef/properties.h>
#include <vespa/fastos/time.h>
-#include <cmath>
#include <vespa/log/log.h>
LOG_SETUP(".features.randomnormalfeature");
@@ -18,53 +17,18 @@ RandomNormalExecutor::RandomNormalExecutor(uint64_t seed, uint64_t matchSeed, do
_matchRnd(),
_matchSeed(matchSeed),
_mean(mean),
- _stddev(stddev),
- _hasSpare(false),
- _spare(0.0)
+ _stddev(stddev)
{
LOG(debug, "RandomNormalExecutor: seed=%zu, matchSeed=%zu, mean=%f, stddev=%f", seed, matchSeed, mean, stddev);
- _rnd.srand48(seed);
+ _rnd.seed(seed);
}
-feature_t generateRandom(Rand48& generator) {
- return (generator.lrand48() / (feature_t)0x80000000u) * 2.0 - 1.0;
-}
-
-/**
- * Draws a random number from the Gaussian distribution
- * using the Marsaglia polar method.
- */
void
RandomNormalExecutor::execute(uint32_t docId)
{
- feature_t result = _spare;
- if (_hasSpare) {
- _hasSpare = false;
- } else {
- _hasSpare = true;
-
- feature_t u, v, s;
- do {
- u = generateRandom(_rnd);
- v = generateRandom(_rnd);
- s = u * u + v * v;
- } while ( (s >= 1.0) || (s == 0.0) );
- s = std::sqrt(-2.0 * std::log(s) / s);
-
- _spare = v * s; // saved for next invocation
- result = u * s;
- }
- outputs().set_number(0, _mean + _stddev * result);
-
- _matchRnd.srand48(_matchSeed + docId);
- feature_t u, v, s;
- do {
- u = generateRandom(_matchRnd);
- v = generateRandom(_matchRnd);
- s = u * u + v * v;
- } while ( (s >= 1.0) || (s == 0.0) );
- s = std::sqrt(-2.0 * std::log(s) / s);
- outputs().set_number(1, _mean + _stddev * u * s);
+ outputs().set_number(0, _mean + _stddev * _rnd.next());
+ _matchRnd.seed(_matchSeed + docId);
+ outputs().set_number(0, _mean + _stddev * _matchRnd.next(false));
}
RandomNormalBlueprint::RandomNormalBlueprint() :
diff --git a/searchlib/src/vespa/searchlib/features/random_normal_feature.h b/searchlib/src/vespa/searchlib/features/random_normal_feature.h
index f2bc82704bb..9ce8f899446 100644
--- a/searchlib/src/vespa/searchlib/features/random_normal_feature.h
+++ b/searchlib/src/vespa/searchlib/features/random_normal_feature.h
@@ -4,7 +4,7 @@
#include <vespa/searchlib/fef/blueprint.h>
#include <vespa/searchlib/fef/featureexecutor.h>
-#include <vespa/searchlib/util/rand48.h>
+#include <vespa/searchlib/util/random_normal.h>
namespace search {
namespace features {
@@ -17,16 +17,13 @@ namespace features {
**/
class RandomNormalExecutor : public fef::FeatureExecutor {
private:
- Rand48 _rnd; // seeded once per query
- Rand48 _matchRnd; // seeded once per match
- uint64_t _matchSeed;
+ RandomNormal _rnd; // seeded once per query
+ RandomNormal _matchRnd; // seeded once per match
+ uint64_t _matchSeed;
double _mean;
double _stddev;
- bool _hasSpare;
- double _spare;
-
public:
RandomNormalExecutor(uint64_t seed, uint64_t matchSeed, double mean, double stddev);
void execute(uint32_t docId) override;
diff --git a/searchlib/src/vespa/searchlib/util/random_normal.h b/searchlib/src/vespa/searchlib/util/random_normal.h
new file mode 100644
index 00000000000..0c2da580db6
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/util/random_normal.h
@@ -0,0 +1,56 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/searchlib/util/rand48.h>
+#include <cmath>
+
+namespace search {
+
+class RandomNormal
+{
+private:
+ Rand48 _rnd;
+ bool _hasSpare;
+ feature_t _spare;
+
+ feature_t nextUniform() {
+ return (_rnd.lrand48() / (feature_t)0x80000000u) * 2.0 - 1.0;
+ }
+
+public:
+ RandomNormal() : _rnd(), _hasSpare(false), _spare(0.0) {}
+
+ void seed(long seed) {
+ _rnd.srand48(seed);
+ }
+
+ /**
+ * Draws a random number from the Gaussian distribution
+ * using the Marsaglia polar method.
+ */
+ feature_t next(bool useSpare = true) {
+ feature_t result = _spare;
+ if (_hasSpare && useSpare) {
+ _hasSpare = false;
+ } else {
+ _hasSpare = true;
+
+ feature_t u, v, s;
+ do {
+ u = nextUniform();
+ v = nextUniform();
+ s = u * u + v * v;
+ } while ( (s >= 1.0) || (s == 0.0) );
+ s = std::sqrt(-2.0 * std::log(s) / s);
+
+ _spare = v * s; // saved for next invocation
+ result = u * s;
+ }
+ return result;
+ }
+
+};
+
+} // search
+