diff options
author | Bjørn Christian Seime <bjorncs@yahoo-inc.com> | 2017-05-04 12:34:10 +0000 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahoo-inc.com> | 2017-05-08 12:59:40 +0200 |
commit | bd9d968854e1723b8f09f6d9ba320e9b3daf774e (patch) | |
tree | 0c11f46a573c37b5bf234c14dcfca0af36845db2 /searchlib/src/tests/aggregator | |
parent | b374b62e10cd3e97a74234117546ba95d29403c1 (diff) |
Add aggregator for calculating the population standard deviation
Diffstat (limited to 'searchlib/src/tests/aggregator')
-rw-r--r-- | searchlib/src/tests/aggregator/perdocexpr.cpp | 107 |
1 files changed, 91 insertions, 16 deletions
diff --git a/searchlib/src/tests/aggregator/perdocexpr.cpp b/searchlib/src/tests/aggregator/perdocexpr.cpp index b68370334c5..6f374f1bea4 100644 --- a/searchlib/src/tests/aggregator/perdocexpr.cpp +++ b/searchlib/src/tests/aggregator/perdocexpr.cpp @@ -46,6 +46,26 @@ void testMin(const ResultNode & a, const ResultNode & b) { ASSERT_TRUE(funcR.getResult().cmp(a) == 0); } +ExpressionNode::UP +createVectorFloat(const std::vector<double> & v) { + std::unique_ptr<FloatResultNodeVector> r = MU<FloatResultNodeVector>(); + r->reserve(v.size()); + for (double d : v) { + r->push_back(FloatResultNode(d)); + } + return MU<ConstantNode>(std::move(r)); +} + +ExpressionNode::UP +createVectorInt(const std::vector<double> & v) { + std::unique_ptr<IntegerResultNodeVector> r = MU<IntegerResultNodeVector>(); + r->reserve(v.size()); + for (double d : v) { + r->push_back(Int64ResultNode(static_cast<int64_t>(d))); + } + return MU<ConstantNode>(std::move(r)); +} + TEST("testMin") { testMin(Int64ResultNode(67), Int64ResultNode(68)); testMin(FloatResultNode(67), FloatResultNode(68)); @@ -155,6 +175,75 @@ TEST("require that expression count estimates rank") { EXPECT_EQUAL(3, func.getRank().getInteger()); } +TEST("require that StandardDeviationAggregationResult can be merged") { + StandardDeviationAggregationResult aggr1; + aggr1.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(8))). + aggregate(DocId(42), HitRank(21)); + + StandardDeviationAggregationResult aggr2; + aggr2.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(10))). + aggregate(DocId(43), HitRank(8)); + + aggr1.merge(aggr2); + EXPECT_EQUAL(2u, aggr1.getCount()); + EXPECT_EQUAL(18.0, aggr1.getSum()); + EXPECT_EQUAL(164.0, aggr1.getSumOfSquared()); +} + +TEST("require that StandardDeviationAggregationResult can be serialized") { + StandardDeviationAggregationResult aggr1; + aggr1.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(8))). + aggregate(DocId(42), HitRank(21)); + + nbostream os; + NBOSerializer nos(os); + nos << aggr1; + Identifiable::UP obj = Identifiable::create(nos); + auto *aggr2 = dynamic_cast<StandardDeviationAggregationResult *>(obj.get()); + ASSERT_TRUE(aggr2); + EXPECT_TRUE(os.empty()); + EXPECT_EQUAL(aggr1.getSumOfSquared(), aggr2->getSumOfSquared()); + EXPECT_EQUAL(aggr1.getSum(), aggr2->getSum()); + EXPECT_EQUAL(aggr1.getCount(), aggr2->getCount()); +} + +TEST("require that StandardDeviationAggregationResult rank is the standard deviation of aggregated values") { + StandardDeviationAggregationResult aggr; + aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(101))). + aggregate(DocId(1), HitRank(21)); + aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(13))). + aggregate(DocId(2), HitRank(8)); + aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(15))). + aggregate(DocId(3), HitRank(30)); + EXPECT_APPROX(41.0203, aggr.getRank().getFloat(), 0.01); +} + +TEST("require that StandardDeviationAggregationResult aggregates multiple expressions correctly") { + StandardDeviationAggregationResult aggr; + aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(1.5))). + aggregate(DocId(1), HitRank(21)); + aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(100.25))). + aggregate(DocId(2), HitRank(8)); + aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(30.125))). + aggregate(DocId(3), HitRank(40)); + + EXPECT_EQUAL(3u, aggr.getCount()); + EXPECT_APPROX(131.875, aggr.getSum(), 0.01); + EXPECT_APPROX(10959.8, aggr.getSumOfSquared(), 0.1); + EXPECT_APPROX(41.5, aggr.getRank().getFloat(), 0.1); +} + +TEST("require that StandardDeviationAggregationResult aggregates multi-value expression correctly") { + StandardDeviationAggregationResult aggr; + aggr.setExpression(createVectorFloat(std::vector<double>({1.5, 100.25, 30.125}))). + aggregate(DocId(42), HitRank(21)); + + EXPECT_EQUAL(3u, aggr.getCount()); + EXPECT_APPROX(131.875, aggr.getSum(), 0.01); + EXPECT_APPROX(10959.8, aggr.getSumOfSquared(), 0.1); + EXPECT_APPROX(41.5, aggr.getRank().getFloat(), 0.1); +} + void testAdd(const ResultNode &a, const ResultNode &b, const ResultNode &c) { AddFunctionNode func; func.appendArg(MU<ConstantNode>(ResultNode::UP(a.clone()))) @@ -1006,22 +1095,7 @@ void testModulo(ExpressionNode::UP arg1, ExpressionNode::UP arg2, testArith(add, std::move(arg1), std::move(arg2), intResult, floatResult); } -ExpressionNode::UP -createVectorInt(const std::vector<double> & v) { - std::unique_ptr<IntegerResultNodeVector> r = MU<IntegerResultNodeVector>(); - for (double d : v) { - r->push_back(Int64ResultNode(static_cast<int64_t>(d))); - } - return MU<ConstantNode>(std::move(r)); -} -ExpressionNode::UP -createVectorFloat(const std::vector<double> & v) { - std::unique_ptr<FloatResultNodeVector> r = MU<FloatResultNodeVector>(); - for (double d : v) { - r->push_back(FloatResultNode(d)); - } - return MU<ConstantNode>(std::move(r)); -} + void testArithmeticArguments(NumericFunctionNode &function, const std::vector<double> & arg1, @@ -1516,6 +1590,7 @@ TEST("testStreamingAll") { testStreaming(RawResultNode("Tester RawResultNode streaming", 30)); testStreaming(CountAggregationResult()); testStreaming(ExpressionCountAggregationResult()); + testStreaming(StandardDeviationAggregationResult()); testStreaming(SumAggregationResult()); testStreaming(MinAggregationResult()); testStreaming(MaxAggregationResult()); |