diff options
author | Bjørn Christian Seime <bjorncs@yahoo-inc.com> | 2017-05-04 12:34:10 +0000 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahoo-inc.com> | 2017-05-08 12:59:40 +0200 |
commit | bd9d968854e1723b8f09f6d9ba320e9b3daf774e (patch) | |
tree | 0c11f46a573c37b5bf234c14dcfca0af36845db2 /searchlib/src | |
parent | b374b62e10cd3e97a74234117546ba95d29403c1 (diff) |
Add aggregator for calculating the population standard deviation
Diffstat (limited to 'searchlib/src')
11 files changed, 347 insertions, 23 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java new file mode 100644 index 00000000000..4edeec7ba77 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java @@ -0,0 +1,90 @@ +// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.aggregation; + +import com.yahoo.searchlib.expression.FloatResultNode; +import com.yahoo.searchlib.expression.ResultNode; +import com.yahoo.vespa.objects.Deserializer; +import com.yahoo.vespa.objects.ObjectVisitor; +import com.yahoo.vespa.objects.Serializer; + +/** + * @author bjorncs + */ +public class StandardDeviationAggregationResult extends AggregationResult { + public static final int classId = registerClass(0x4000 + 89, StandardDeviationAggregationResult.class); + + private long count; + private double sum; + private double sumOfSquared; + + /** + * Constructor used for deserialization. Creates an empty result (count 0, sum 0.0, sumOfSquared 0.0). 
+ */ + @SuppressWarnings("unused") + public StandardDeviationAggregationResult() { + this(0, 0.0, 0.0); + } + + public StandardDeviationAggregationResult(long count, double sum, double sumOfSquared) { + this.count = count; + this.sum = sum; + this.sumOfSquared = sumOfSquared; + } + + public double getStandardDeviation() { + if (count == 0) { + return 0; + } else { + double variance = (sumOfSquared - sum * sum / count) / count; + return Math.sqrt(variance); + } + } + + @Override + public ResultNode getRank() { + return new FloatResultNode(getStandardDeviation()); + } + + @Override + protected void onMerge(AggregationResult obj) { + StandardDeviationAggregationResult other = (StandardDeviationAggregationResult) obj; + count += other.count; + sum += other.sum; + sumOfSquared += other.sumOfSquared; + } + + @Override + protected boolean equalsAggregation(AggregationResult obj) { + StandardDeviationAggregationResult other = (StandardDeviationAggregationResult) obj; + return count == other.count && sum == other.sum && sumOfSquared == other.sumOfSquared; + } + + @Override + protected void onSerialize(Serializer buf) { + super.onSerialize(buf); + buf.putLong(null, count); + buf.putDouble(null, sum); + buf.putDouble(null, sumOfSquared); + } + + @Override + protected void onDeserialize(Deserializer buf) { + super.onDeserialize(buf); + count = buf.getLong(null); + sum = buf.getDouble(null); + sumOfSquared = buf.getDouble(null); + } + + @Override + protected int onGetClassId() { + return classId; + } + + @Override + public void visitMembers(ObjectVisitor visitor) { + super.visitMembers(visitor); + visitor.visit("count", count); + visitor.visit("sum", sum); + visitor.visit("sumOfSquared", sumOfSquared); + } +} diff --git a/searchlib/src/test/files/testAggregatorResults b/searchlib/src/test/files/testAggregatorResults Binary files differindex 060b8b86bda..839913ee513 100644 --- a/searchlib/src/test/files/testAggregatorResults +++ b/searchlib/src/test/files/testAggregatorResults 
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java index a9926f7c0e2..c684db11413 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java @@ -155,6 +155,8 @@ public class GroupingSerializationTest { sketch.aggregate(1955583074); t.assertMatch(new ExpressionCountAggregationResult(sketch, s -> 42) .setExpression(new ConstantNode(new IntegerResultNode(67)))); + t.assertMatch(new StandardDeviationAggregationResult(1, 67, 67 * 67) + .setExpression(new ConstantNode(new IntegerResultNode(67)))); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResultTest.java b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResultTest.java new file mode 100644 index 00000000000..9de5f05b092 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResultTest.java @@ -0,0 +1,21 @@ +package com.yahoo.searchlib.aggregation; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class StandardDeviationAggregationResultTest { + + @Test + public void rank_is_standard_deviation() { + StandardDeviationAggregationResult aggregationResult = + new StandardDeviationAggregationResult(3, 131.875, 10595.8); + double rank = aggregationResult.getRank().getFloat(); + assertEquals(aggregationResult.getStandardDeviation(), rank, 0); + assertEquals(40, rank, 0.1); + } + +} diff --git a/searchlib/src/tests/aggregator/perdocexpr.cpp b/searchlib/src/tests/aggregator/perdocexpr.cpp index b68370334c5..6f374f1bea4 100644 --- a/searchlib/src/tests/aggregator/perdocexpr.cpp +++ b/searchlib/src/tests/aggregator/perdocexpr.cpp @@ -46,6 +46,26 @@ void 
testMin(const ResultNode & a, const ResultNode & b) { ASSERT_TRUE(funcR.getResult().cmp(a) == 0); } +ExpressionNode::UP +createVectorFloat(const std::vector<double> & v) { + std::unique_ptr<FloatResultNodeVector> r = MU<FloatResultNodeVector>(); + r->reserve(v.size()); + for (double d : v) { + r->push_back(FloatResultNode(d)); + } + return MU<ConstantNode>(std::move(r)); +} + +ExpressionNode::UP +createVectorInt(const std::vector<double> & v) { + std::unique_ptr<IntegerResultNodeVector> r = MU<IntegerResultNodeVector>(); + r->reserve(v.size()); + for (double d : v) { + r->push_back(Int64ResultNode(static_cast<int64_t>(d))); + } + return MU<ConstantNode>(std::move(r)); +} + TEST("testMin") { testMin(Int64ResultNode(67), Int64ResultNode(68)); testMin(FloatResultNode(67), FloatResultNode(68)); @@ -155,6 +175,75 @@ TEST("require that expression count estimates rank") { EXPECT_EQUAL(3, func.getRank().getInteger()); } +TEST("require that StandardDeviationAggregationResult can be merged") { + StandardDeviationAggregationResult aggr1; + aggr1.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(8))). + aggregate(DocId(42), HitRank(21)); + + StandardDeviationAggregationResult aggr2; + aggr2.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(10))). + aggregate(DocId(43), HitRank(8)); + + aggr1.merge(aggr2); + EXPECT_EQUAL(2u, aggr1.getCount()); + EXPECT_EQUAL(18.0, aggr1.getSum()); + EXPECT_EQUAL(164.0, aggr1.getSumOfSquared()); +} + +TEST("require that StandardDeviationAggregationResult can be serialized") { + StandardDeviationAggregationResult aggr1; + aggr1.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(8))). 
+ aggregate(DocId(42), HitRank(21)); + + nbostream os; + NBOSerializer nos(os); + nos << aggr1; + Identifiable::UP obj = Identifiable::create(nos); + auto *aggr2 = dynamic_cast<StandardDeviationAggregationResult *>(obj.get()); + ASSERT_TRUE(aggr2); + EXPECT_TRUE(os.empty()); + EXPECT_EQUAL(aggr1.getSumOfSquared(), aggr2->getSumOfSquared()); + EXPECT_EQUAL(aggr1.getSum(), aggr2->getSum()); + EXPECT_EQUAL(aggr1.getCount(), aggr2->getCount()); +} + +TEST("require that StandardDeviationAggregationResult rank is the standard deviation of aggregated values") { + StandardDeviationAggregationResult aggr; + aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(101))). + aggregate(DocId(1), HitRank(21)); + aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(13))). + aggregate(DocId(2), HitRank(8)); + aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(15))). + aggregate(DocId(3), HitRank(30)); + EXPECT_APPROX(41.0203, aggr.getRank().getFloat(), 0.01); +} + +TEST("require that StandardDeviationAggregationResult aggregates multiple expressions correctly") { + StandardDeviationAggregationResult aggr; + aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(1.5))). + aggregate(DocId(1), HitRank(21)); + aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(100.25))). + aggregate(DocId(2), HitRank(8)); + aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(30.125))). + aggregate(DocId(3), HitRank(40)); + + EXPECT_EQUAL(3u, aggr.getCount()); + EXPECT_APPROX(131.875, aggr.getSum(), 0.01); + EXPECT_APPROX(10959.8, aggr.getSumOfSquared(), 0.1); + EXPECT_APPROX(41.5, aggr.getRank().getFloat(), 0.1); +} + +TEST("require that StandardDeviationAggregationResult aggregates multi-value expression correctly") { + StandardDeviationAggregationResult aggr; + aggr.setExpression(createVectorFloat(std::vector<double>({1.5, 100.25, 30.125}))). 
+ aggregate(DocId(42), HitRank(21)); + + EXPECT_EQUAL(3u, aggr.getCount()); + EXPECT_APPROX(131.875, aggr.getSum(), 0.01); + EXPECT_APPROX(10959.8, aggr.getSumOfSquared(), 0.1); + EXPECT_APPROX(41.5, aggr.getRank().getFloat(), 0.1); +} + void testAdd(const ResultNode &a, const ResultNode &b, const ResultNode &c) { AddFunctionNode func; func.appendArg(MU<ConstantNode>(ResultNode::UP(a.clone()))) @@ -1006,22 +1095,7 @@ void testModulo(ExpressionNode::UP arg1, ExpressionNode::UP arg2, testArith(add, std::move(arg1), std::move(arg2), intResult, floatResult); } -ExpressionNode::UP -createVectorInt(const std::vector<double> & v) { - std::unique_ptr<IntegerResultNodeVector> r = MU<IntegerResultNodeVector>(); - for (double d : v) { - r->push_back(Int64ResultNode(static_cast<int64_t>(d))); - } - return MU<ConstantNode>(std::move(r)); -} -ExpressionNode::UP -createVectorFloat(const std::vector<double> & v) { - std::unique_ptr<FloatResultNodeVector> r = MU<FloatResultNodeVector>(); - for (double d : v) { - r->push_back(FloatResultNode(d)); - } - return MU<ConstantNode>(std::move(r)); -} + void testArithmeticArguments(NumericFunctionNode &function, const std::vector<double> & arg1, @@ -1516,6 +1590,7 @@ TEST("testStreamingAll") { testStreaming(RawResultNode("Tester RawResultNode streaming", 30)); testStreaming(CountAggregationResult()); testStreaming(ExpressionCountAggregationResult()); + testStreaming(StandardDeviationAggregationResult()); testStreaming(SumAggregationResult()); testStreaming(MinAggregationResult()); testStreaming(MaxAggregationResult()); diff --git a/searchlib/src/tests/grouping/grouping_serialization_test.cpp b/searchlib/src/tests/grouping/grouping_serialization_test.cpp index 8592995915c..734c9095d48 100644 --- a/searchlib/src/tests/grouping/grouping_serialization_test.cpp +++ b/searchlib/src/tests/grouping/grouping_serialization_test.cpp @@ -229,6 +229,10 @@ TEST_F("testAggregatorResults", Fixture("testAggregatorResults")) { 
expression_count.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(67))) .aggregate(DocId(42), HitRank(21)); f.checkObject(expression_count); + StandardDeviationAggregationResult stddev; + stddev.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(67))) + .aggregate(DocId(42), HitRank(21)); + f.checkObject(stddev); } TEST_F("testHitCollection", Fixture("testHitCollection")) { diff --git a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp index 1cfc9804390..9ccad83df19 100644 --- a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp +++ b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp @@ -37,6 +37,7 @@ IMPLEMENT_AGGREGATIONRESULT(MinAggregationResult, AggregationResult); IMPLEMENT_AGGREGATIONRESULT(AverageAggregationResult, AggregationResult); IMPLEMENT_AGGREGATIONRESULT(XorAggregationResult, AggregationResult); IMPLEMENT_AGGREGATIONRESULT(ExpressionCountAggregationResult, AggregationResult); +IMPLEMENT_AGGREGATIONRESULT(StandardDeviationAggregationResult, AggregationResult); AggregationResult::AggregationResult() : _expressionTree(new ExpressionTree()), @@ -501,6 +502,82 @@ Deserializer &ExpressionCountAggregationResult::onDeserialize( ExpressionCountAggregationResult::ExpressionCountAggregationResult() : AggregationResult(), _hll() { } ExpressionCountAggregationResult::~ExpressionCountAggregationResult() {} +StandardDeviationAggregationResult::StandardDeviationAggregationResult() + : AggregationResult(), _count(), _sum(), _sumOfSquared(), _stdDevScratchPad() +{ + _stdDevScratchPad.reset(new expression::FloatResultNode()); +} + +StandardDeviationAggregationResult::~StandardDeviationAggregationResult() {} + +const NumericResultNode& StandardDeviationAggregationResult::getStandardDeviation() const noexcept +{ + if (_count == 0) { + _stdDevScratchPad->set(Int64ResultNode(0)); + } else { + double variance = (_sumOfSquared.getFloat() - _sum.getFloat() * _sum.getFloat() / _count) / 
_count; + double stddev = std::sqrt(variance); + _stdDevScratchPad->set(FloatResultNode(stddev)); + } + return *_stdDevScratchPad; +} + +void StandardDeviationAggregationResult::onMerge(const AggregationResult &r) { + const StandardDeviationAggregationResult &result = + Identifiable::cast<const StandardDeviationAggregationResult &>(r); + _count += result._count; + _sum.add(result._sum); + _sumOfSquared.add(result._sumOfSquared); +} + +void StandardDeviationAggregationResult::onAggregate(const ResultNode &result) { + if (result.isMultiValue()) { + static_cast<const ResultNodeVector &>(result).flattenSum(_sum); + static_cast<const ResultNodeVector &>(result).flattenSumOfSquared(_sumOfSquared); + _count += static_cast<const ResultNodeVector &>(result).size(); + } else { + _sum.add(result); + FloatResultNode squared(result.getFloat()); + squared.multiply(result); + _sumOfSquared.add(squared); + _count++; + } +} + +void StandardDeviationAggregationResult::onReset() +{ + _count = 0; + _sum.set(0.0); + _sumOfSquared.set(0.0); +} + +Serializer & StandardDeviationAggregationResult::onSerialize(Serializer & os) const +{ + AggregationResult::onSerialize(os); + double sum = _sum.getFloat(); + double sumOfSquared = _sumOfSquared.getFloat(); + return os << _count << sum << sumOfSquared; +} + +Deserializer & StandardDeviationAggregationResult::onDeserialize(Deserializer & is) +{ + AggregationResult::onDeserialize(is); + double sum; + double sumOfSquared; + auto& r = is >> _count >> sum >> sumOfSquared; + _sum.set(sum); + _sumOfSquared.set(sumOfSquared); + return r; +} + +void StandardDeviationAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const +{ + AggregationResult::visitMembers(visitor); + visit(visitor, "count", _count); + visit(visitor, "sum", _sum); + visit(visitor, "sumOfSquared", _sumOfSquared); +} + } // namespace aggregation } // namespace search diff --git a/searchlib/src/vespa/searchlib/aggregation/aggregation.h 
b/searchlib/src/vespa/searchlib/aggregation/aggregation.h index e17ee88c113..ad5d48b35b8 100644 --- a/searchlib/src/vespa/searchlib/aggregation/aggregation.h +++ b/searchlib/src/vespa/searchlib/aggregation/aggregation.h @@ -11,6 +11,7 @@ #include <vespa/searchlib/aggregation/averageaggregationresult.h> #include <vespa/searchlib/aggregation/xoraggregationresult.h> #include <vespa/searchlib/aggregation/hitsaggregationresult.h> +#include <vespa/searchlib/aggregation/standarddeviationaggregationresult.h> #include <vespa/searchlib/aggregation/grouping.h> namespace search { diff --git a/searchlib/src/vespa/searchlib/aggregation/standarddeviationaggregationresult.h b/searchlib/src/vespa/searchlib/aggregation/standarddeviationaggregationresult.h new file mode 100644 index 00000000000..283637d5f0b --- /dev/null +++ b/searchlib/src/vespa/searchlib/aggregation/standarddeviationaggregationresult.h @@ -0,0 +1,35 @@ +// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+#pragma once + +#include "aggregationresult.h" +#include <vespa/searchlib/expression/floatresultnode.h> +#include <vespa/searchlib/expression/integerresultnode.h> + + +namespace search::aggregation { + +// Aggregator that calculates the population standard deviation +class StandardDeviationAggregationResult : public AggregationResult +{ +public: + DECLARE_AGGREGATIONRESULT(StandardDeviationAggregationResult); + StandardDeviationAggregationResult(); + ~StandardDeviationAggregationResult(); + + void visitMembers(vespalib::ObjectVisitor &visitor) const override; + double getSum() const noexcept { return _sum.getFloat(); } + double getSumOfSquared() const noexcept { return _sumOfSquared.getFloat(); } + uint64_t getCount() const noexcept { return _count; } +private: + const ResultNode& onGetRank() const noexcept override { return getStandardDeviation(); } + void onPrepare(const ResultNode&, bool) override { }; + const expression::NumericResultNode& getStandardDeviation() const noexcept; + + uint64_t _count; + expression::FloatResultNode _sum; + expression::FloatResultNode _sumOfSquared; + mutable expression::FloatResultNode::CP _stdDevScratchPad; +}; + +} + diff --git a/searchlib/src/vespa/searchlib/common/identifiable.h b/searchlib/src/vespa/searchlib/common/identifiable.h index a2a2dfdb7bb..5ee131ccd56 100644 --- a/searchlib/src/vespa/searchlib/common/identifiable.h +++ b/searchlib/src/vespa/searchlib/common/identifiable.h @@ -90,6 +90,8 @@ #define CID_search_aggregation_HitsAggregationResult SEARCHLIB_CID(87) #define CID_search_aggregation_ExpressionCountAggregationResult \ SEARCHLIB_CID(88) +#define CID_search_aggregation_StandardDeviationAggregationResult \ + SEARCHLIB_CID(89) #define CID_search_aggregation_Group SEARCHLIB_CID(90) #define CID_search_aggregation_Grouping SEARCHLIB_CID(91) diff --git a/searchlib/src/vespa/searchlib/expression/resultvector.h b/searchlib/src/vespa/searchlib/expression/resultvector.h index 9d6b409449e..f91354ce771 100644 --- 
a/searchlib/src/vespa/searchlib/expression/resultvector.h +++ b/searchlib/src/vespa/searchlib/expression/resultvector.h @@ -31,6 +31,7 @@ public: virtual ResultNode & get(size_t index) = 0; virtual void clear() = 0; virtual void resize(size_t sz) = 0; + virtual void reserve(size_t sz) = 0; size_t size() const { return onSize(); } bool empty() const { return size() == 0; } /** @@ -44,6 +45,7 @@ public: virtual ResultNode & flattenAnd(ResultNode & r) const { return r; } virtual ResultNode & flattenOr(ResultNode & r) const { return r; } virtual ResultNode & flattenXor(ResultNode & r) const { return r; } + virtual ResultNode & flattenSumOfSquared(ResultNode & r) const { return r; } virtual void min(const ResultNode & b) { (void) b; } virtual void max(const ResultNode & b) { (void) b; } virtual void add(const ResultNode & b) { (void) b; } @@ -92,6 +94,7 @@ public: ResultNode & get(size_t index) override { return _result[index]; } void clear() override { _result.clear(); } void resize(size_t sz) override { _result.resize(sz); } + void reserve(size_t sz) override { _result.reserve(sz); } void negate() override; private: void visitMembers(vespalib::ObjectVisitor &visitor) const override { visit(visitor, "Vector", _result); } @@ -219,7 +222,7 @@ public: B v; v.set(r); const std::vector<B> & vec(this->getVector()); - for(size_t i(0), m(vec.size()); i < m; i++) { + for (size_t i(0), m(vec.size()); i < m; i++) { v.multiply(vec[i]); } r.set(v); @@ -229,7 +232,7 @@ public: Int64ResultNode v; v.set(r); const std::vector<B> & vec(this->getVector()); - for(size_t i(0), m(vec.size()); i < m; i++) { + for (size_t i(0), m(vec.size()); i < m; i++) { v.andOp(vec[i]); } r.set(v); @@ -239,7 +242,7 @@ public: Int64ResultNode v; v.set(r); const std::vector<B> & vec(this->getVector()); - for(size_t i(0), m(vec.size()); i < m; i++) { + for (size_t i(0), m(vec.size()); i < m; i++) { v.orOp(vec[i]); } r.set(v); @@ -249,7 +252,7 @@ public: Int64ResultNode v; v.set(r); const std::vector<B> & 
vec(this->getVector()); - for(size_t i(0), m(vec.size()); i < m; i++) { + for (size_t i(0), m(vec.size()); i < m; i++) { v.xorOp(vec[i]); } r.set(v); @@ -259,7 +262,7 @@ public: B v; v.set(r); const std::vector<B> & vec(this->getVector()); - for(size_t i(0), m(vec.size()); i < m; i++) { + for (size_t i(0), m(vec.size()); i < m; i++) { v.add(vec[i]); } r.set(v); @@ -269,7 +272,7 @@ public: B v; v.set(r); const std::vector<B> & vec(this->getVector()); - for(size_t i(0), m(vec.size()); i < m; i++) { + for (size_t i(0), m(vec.size()); i < m; i++) { v.max(vec[i]); } r.set(v); @@ -279,12 +282,25 @@ public: B v; v.set(r); const std::vector<B> & vec(this->getVector()); - for(size_t i(0), m(vec.size()); i < m; i++) { + for (size_t i(0), m(vec.size()); i < m; i++) { v.min(vec[i]); } r.set(v); return r; } + ResultNode & flattenSumOfSquared(ResultNode & r) const override { + B v; + v.set(r); + const std::vector<B> & vec(this->getVector()); + for (size_t i(0), m(vec.size()); i < m; i++) { + B squared; + squared.set(vec[i]); + squared.multiply(vec[i]); + v.add(squared); + } + r.set(v); + return r; + } }; @@ -400,6 +416,7 @@ public: ResultNode & get(size_t index) override { return *_v[index]; } void clear() override { _v.clear(); } void resize(size_t sz) override { _v.resize(sz); } + void reserve(size_t sz) override { _v.reserve(sz); } private: int64_t onGetInteger(size_t index) const override { return _v[index]->getInteger(index); } double onGetFloat(size_t index) const override { return _v[index]->getFloat(index); } |