summary refs log tree commit diff stats
path: root/searchlib/src/tests/aggregator
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorncs@yahoo-inc.com> 2017-05-04 12:34:10 +0000
committer Bjørn Christian Seime <bjorncs@yahoo-inc.com> 2017-05-08 12:59:40 +0200
commit bd9d968854e1723b8f09f6d9ba320e9b3daf774e (patch)
tree 0c11f46a573c37b5bf234c14dcfca0af36845db2 /searchlib/src/tests/aggregator
parent b374b62e10cd3e97a74234117546ba95d29403c1 (diff)
Add aggregator for calculating the population standard deviation
Diffstat (limited to 'searchlib/src/tests/aggregator')
-rw-r--r-- searchlib/src/tests/aggregator/perdocexpr.cpp 107
1 file changed, 91 insertions, 16 deletions
diff --git a/searchlib/src/tests/aggregator/perdocexpr.cpp b/searchlib/src/tests/aggregator/perdocexpr.cpp
index b68370334c5..6f374f1bea4 100644
--- a/searchlib/src/tests/aggregator/perdocexpr.cpp
+++ b/searchlib/src/tests/aggregator/perdocexpr.cpp
@@ -46,6 +46,26 @@ void testMin(const ResultNode & a, const ResultNode & b) {
ASSERT_TRUE(funcR.getResult().cmp(a) == 0);
}
+ExpressionNode::UP
+createVectorFloat(const std::vector<double> & v) {
+ std::unique_ptr<FloatResultNodeVector> r = MU<FloatResultNodeVector>();
+ r->reserve(v.size());
+ for (double d : v) {
+ r->push_back(FloatResultNode(d));
+ }
+ return MU<ConstantNode>(std::move(r));
+}
+
+ExpressionNode::UP
+createVectorInt(const std::vector<double> & v) {
+ std::unique_ptr<IntegerResultNodeVector> r = MU<IntegerResultNodeVector>();
+ r->reserve(v.size());
+ for (double d : v) {
+ r->push_back(Int64ResultNode(static_cast<int64_t>(d)));
+ }
+ return MU<ConstantNode>(std::move(r));
+}
+
TEST("testMin") {
testMin(Int64ResultNode(67), Int64ResultNode(68));
testMin(FloatResultNode(67), FloatResultNode(68));
@@ -155,6 +175,75 @@ TEST("require that expression count estimates rank") {
EXPECT_EQUAL(3, func.getRank().getInteger());
}
+TEST("require that StandardDeviationAggregationResult can be merged") {
+ StandardDeviationAggregationResult aggr1;
+ aggr1.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(8))).
+ aggregate(DocId(42), HitRank(21));
+
+ StandardDeviationAggregationResult aggr2;
+ aggr2.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(10))).
+ aggregate(DocId(43), HitRank(8));
+
+ aggr1.merge(aggr2);
+ EXPECT_EQUAL(2u, aggr1.getCount());
+ EXPECT_EQUAL(18.0, aggr1.getSum());
+ EXPECT_EQUAL(164.0, aggr1.getSumOfSquared());
+}
+
+TEST("require that StandardDeviationAggregationResult can be serialized") {
+ StandardDeviationAggregationResult aggr1;
+ aggr1.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(8))).
+ aggregate(DocId(42), HitRank(21));
+
+ nbostream os;
+ NBOSerializer nos(os);
+ nos << aggr1;
+ Identifiable::UP obj = Identifiable::create(nos);
+ auto *aggr2 = dynamic_cast<StandardDeviationAggregationResult *>(obj.get());
+ ASSERT_TRUE(aggr2);
+ EXPECT_TRUE(os.empty());
+ EXPECT_EQUAL(aggr1.getSumOfSquared(), aggr2->getSumOfSquared());
+ EXPECT_EQUAL(aggr1.getSum(), aggr2->getSum());
+ EXPECT_EQUAL(aggr1.getCount(), aggr2->getCount());
+}
+
+TEST("require that StandardDeviationAggregationResult rank is the standard deviation of aggregated values") {
+ StandardDeviationAggregationResult aggr;
+ aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(101))).
+ aggregate(DocId(1), HitRank(21));
+ aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(13))).
+ aggregate(DocId(2), HitRank(8));
+ aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(15))).
+ aggregate(DocId(3), HitRank(30));
+ EXPECT_APPROX(41.0203, aggr.getRank().getFloat(), 0.01);
+}
+
+TEST("require that StandardDeviationAggregationResult aggregates multiple expressions correctly") {
+ StandardDeviationAggregationResult aggr;
+ aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(1.5))).
+ aggregate(DocId(1), HitRank(21));
+ aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(100.25))).
+ aggregate(DocId(2), HitRank(8));
+ aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(30.125))).
+ aggregate(DocId(3), HitRank(40));
+
+ EXPECT_EQUAL(3u, aggr.getCount());
+ EXPECT_APPROX(131.875, aggr.getSum(), 0.01);
+ EXPECT_APPROX(10959.8, aggr.getSumOfSquared(), 0.1);
+ EXPECT_APPROX(41.5, aggr.getRank().getFloat(), 0.1);
+}
+
+TEST("require that StandardDeviationAggregationResult aggregates multi-value expression correctly") {
+ StandardDeviationAggregationResult aggr;
+ aggr.setExpression(createVectorFloat(std::vector<double>({1.5, 100.25, 30.125}))).
+ aggregate(DocId(42), HitRank(21));
+
+ EXPECT_EQUAL(3u, aggr.getCount());
+ EXPECT_APPROX(131.875, aggr.getSum(), 0.01);
+ EXPECT_APPROX(10959.8, aggr.getSumOfSquared(), 0.1);
+ EXPECT_APPROX(41.5, aggr.getRank().getFloat(), 0.1);
+}
+
void testAdd(const ResultNode &a, const ResultNode &b, const ResultNode &c) {
AddFunctionNode func;
func.appendArg(MU<ConstantNode>(ResultNode::UP(a.clone())))
@@ -1006,22 +1095,7 @@ void testModulo(ExpressionNode::UP arg1, ExpressionNode::UP arg2,
testArith(add, std::move(arg1), std::move(arg2), intResult, floatResult);
}
-ExpressionNode::UP
-createVectorInt(const std::vector<double> & v) {
- std::unique_ptr<IntegerResultNodeVector> r = MU<IntegerResultNodeVector>();
- for (double d : v) {
- r->push_back(Int64ResultNode(static_cast<int64_t>(d)));
- }
- return MU<ConstantNode>(std::move(r));
-}
-ExpressionNode::UP
-createVectorFloat(const std::vector<double> & v) {
- std::unique_ptr<FloatResultNodeVector> r = MU<FloatResultNodeVector>();
- for (double d : v) {
- r->push_back(FloatResultNode(d));
- }
- return MU<ConstantNode>(std::move(r));
-}
+
void testArithmeticArguments(NumericFunctionNode &function,
const std::vector<double> & arg1,
@@ -1516,6 +1590,7 @@ TEST("testStreamingAll") {
testStreaming(RawResultNode("Tester RawResultNode streaming", 30));
testStreaming(CountAggregationResult());
testStreaming(ExpressionCountAggregationResult());
+ testStreaming(StandardDeviationAggregationResult());
testStreaming(SumAggregationResult());
testStreaming(MinAggregationResult());
testStreaming(MaxAggregationResult());