summaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahoo-inc.com>2017-05-04 12:34:10 +0000
committerBjørn Christian Seime <bjorncs@yahoo-inc.com>2017-05-08 12:59:40 +0200
commitbd9d968854e1723b8f09f6d9ba320e9b3daf774e (patch)
tree0c11f46a573c37b5bf234c14dcfca0af36845db2 /searchlib/src
parentb374b62e10cd3e97a74234117546ba95d29403c1 (diff)
Add aggregator for calculating the population standard deviation
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java90
-rw-r--r--searchlib/src/test/files/testAggregatorResultsbin310 -> 364 bytes
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResultTest.java21
-rw-r--r--searchlib/src/tests/aggregator/perdocexpr.cpp107
-rw-r--r--searchlib/src/tests/grouping/grouping_serialization_test.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/aggregation/aggregation.cpp77
-rw-r--r--searchlib/src/vespa/searchlib/aggregation/aggregation.h1
-rw-r--r--searchlib/src/vespa/searchlib/aggregation/standarddeviationaggregationresult.h35
-rw-r--r--searchlib/src/vespa/searchlib/common/identifiable.h2
-rw-r--r--searchlib/src/vespa/searchlib/expression/resultvector.h31
11 files changed, 347 insertions, 23 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java
new file mode 100644
index 00000000000..4edeec7ba77
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java
@@ -0,0 +1,90 @@
+// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.aggregation;
+
+import com.yahoo.searchlib.expression.FloatResultNode;
+import com.yahoo.searchlib.expression.ResultNode;
+import com.yahoo.vespa.objects.Deserializer;
+import com.yahoo.vespa.objects.ObjectVisitor;
+import com.yahoo.vespa.objects.Serializer;
+
+/**
+ * @author bjorncs
+ */
+public class StandardDeviationAggregationResult extends AggregationResult {
+ public static final int classId = registerClass(0x4000 + 89, StandardDeviationAggregationResult.class);
+
+ private long count;
+ private double sum;
+ private double sumOfSquared;
+
+ /**
+ * Constructor used for deserialization. Will be instantiated with a default sketch.
+ */
+ @SuppressWarnings("unused")
+ public StandardDeviationAggregationResult() {
+ this(0, 0.0, 0.0);
+ }
+
+ public StandardDeviationAggregationResult(long count, double sum, double sumOfSquared) {
+ this.count = count;
+ this.sum = sum;
+ this.sumOfSquared = sumOfSquared;
+ }
+
+ public double getStandardDeviation() {
+ if (count == 0) {
+ return 0;
+ } else {
+ double variance = (sumOfSquared - sum * sum / count) / count;
+ return Math.sqrt(variance);
+ }
+ }
+
+ @Override
+ public ResultNode getRank() {
+ return new FloatResultNode(getStandardDeviation());
+ }
+
+ @Override
+ protected void onMerge(AggregationResult obj) {
+ StandardDeviationAggregationResult other = (StandardDeviationAggregationResult) obj;
+ count += other.count;
+ sum += other.sum;
+ sumOfSquared += other.sumOfSquared;
+ }
+
+ @Override
+ protected boolean equalsAggregation(AggregationResult obj) {
+ StandardDeviationAggregationResult other = (StandardDeviationAggregationResult) obj;
+ return count == this.count && sum == other.sum && sumOfSquared == other.sumOfSquared;
+ }
+
+ @Override
+ protected void onSerialize(Serializer buf) {
+ super.onSerialize(buf);
+ buf.putLong(null, count);
+ buf.putDouble(null, sum);
+ buf.putDouble(null, sumOfSquared);
+ }
+
+ @Override
+ protected void onDeserialize(Deserializer buf) {
+ super.onDeserialize(buf);
+ count = buf.getLong(null);
+ sum = buf.getDouble(null);
+ sumOfSquared = buf.getDouble(null);
+ }
+
+ @Override
+ protected int onGetClassId() {
+ return classId;
+ }
+
+ @Override
+ public void visitMembers(ObjectVisitor visitor) {
+ super.visitMembers(visitor);
+ visitor.visit("count", count);
+ visitor.visit("sum", sum);
+ visitor.visit("sumOfSquared", sumOfSquared);
+ }
+}
diff --git a/searchlib/src/test/files/testAggregatorResults b/searchlib/src/test/files/testAggregatorResults
index 060b8b86bda..839913ee513 100644
--- a/searchlib/src/test/files/testAggregatorResults
+++ b/searchlib/src/test/files/testAggregatorResults
Binary files differ
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
index a9926f7c0e2..c684db11413 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
@@ -155,6 +155,8 @@ public class GroupingSerializationTest {
sketch.aggregate(1955583074);
t.assertMatch(new ExpressionCountAggregationResult(sketch, s -> 42)
.setExpression(new ConstantNode(new IntegerResultNode(67))));
+ t.assertMatch(new StandardDeviationAggregationResult(1, 67, 67 * 67)
+ .setExpression(new ConstantNode(new IntegerResultNode(67))));
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResultTest.java b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResultTest.java
new file mode 100644
index 00000000000..9de5f05b092
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResultTest.java
@@ -0,0 +1,21 @@
+package com.yahoo.searchlib.aggregation;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bjorncs
+ */
+public class StandardDeviationAggregationResultTest {
+
+ @Test
+ public void rank_is_standard_deviation() {
+ StandardDeviationAggregationResult aggregationResult =
+ new StandardDeviationAggregationResult(3, 131.875, 10595.8);
+ double rank = aggregationResult.getRank().getFloat();
+ assertEquals(aggregationResult.getStandardDeviation(), rank, 0);
+ assertEquals(40, rank, 0.1);
+ }
+
+}
diff --git a/searchlib/src/tests/aggregator/perdocexpr.cpp b/searchlib/src/tests/aggregator/perdocexpr.cpp
index b68370334c5..6f374f1bea4 100644
--- a/searchlib/src/tests/aggregator/perdocexpr.cpp
+++ b/searchlib/src/tests/aggregator/perdocexpr.cpp
@@ -46,6 +46,26 @@ void testMin(const ResultNode & a, const ResultNode & b) {
ASSERT_TRUE(funcR.getResult().cmp(a) == 0);
}
+ExpressionNode::UP
+createVectorFloat(const std::vector<double> & v) {
+ std::unique_ptr<FloatResultNodeVector> r = MU<FloatResultNodeVector>();
+ r->reserve(v.size());
+ for (double d : v) {
+ r->push_back(FloatResultNode(d));
+ }
+ return MU<ConstantNode>(std::move(r));
+}
+
+ExpressionNode::UP
+createVectorInt(const std::vector<double> & v) {
+ std::unique_ptr<IntegerResultNodeVector> r = MU<IntegerResultNodeVector>();
+ r->reserve(v.size());
+ for (double d : v) {
+ r->push_back(Int64ResultNode(static_cast<int64_t>(d)));
+ }
+ return MU<ConstantNode>(std::move(r));
+}
+
TEST("testMin") {
testMin(Int64ResultNode(67), Int64ResultNode(68));
testMin(FloatResultNode(67), FloatResultNode(68));
@@ -155,6 +175,75 @@ TEST("require that expression count estimates rank") {
EXPECT_EQUAL(3, func.getRank().getInteger());
}
+TEST("require that StandardDeviationAggregationResult can be merged") {
+ StandardDeviationAggregationResult aggr1;
+ aggr1.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(8))).
+ aggregate(DocId(42), HitRank(21));
+
+ StandardDeviationAggregationResult aggr2;
+ aggr2.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(10))).
+ aggregate(DocId(43), HitRank(8));
+
+ aggr1.merge(aggr2);
+ EXPECT_EQUAL(2u, aggr1.getCount());
+ EXPECT_EQUAL(18.0, aggr1.getSum());
+ EXPECT_EQUAL(164.0, aggr1.getSumOfSquared());
+}
+
+TEST("require that StandardDeviationAggregationResult can be serialized") {
+ StandardDeviationAggregationResult aggr1;
+ aggr1.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(8))).
+ aggregate(DocId(42), HitRank(21));
+
+ nbostream os;
+ NBOSerializer nos(os);
+ nos << aggr1;
+ Identifiable::UP obj = Identifiable::create(nos);
+ auto *aggr2 = dynamic_cast<StandardDeviationAggregationResult *>(obj.get());
+ ASSERT_TRUE(aggr2);
+ EXPECT_TRUE(os.empty());
+ EXPECT_EQUAL(aggr1.getSumOfSquared(), aggr2->getSumOfSquared());
+ EXPECT_EQUAL(aggr1.getSum(), aggr2->getSum());
+ EXPECT_EQUAL(aggr1.getCount(), aggr2->getCount());
+}
+
+TEST("require that StandardDeviationAggregationResult rank is the standard deviation of aggregated values") {
+ StandardDeviationAggregationResult aggr;
+ aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(101))).
+ aggregate(DocId(1), HitRank(21));
+ aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(13))).
+ aggregate(DocId(2), HitRank(8));
+ aggr.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(15))).
+ aggregate(DocId(3), HitRank(30));
+ EXPECT_APPROX(41.0203, aggr.getRank().getFloat(), 0.01);
+}
+
+TEST("require that StandardDeviationAggregationResult aggregates multiple expressions correctly") {
+ StandardDeviationAggregationResult aggr;
+ aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(1.5))).
+ aggregate(DocId(1), HitRank(21));
+ aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(100.25))).
+ aggregate(DocId(2), HitRank(8));
+ aggr.setExpression(MU<ConstantNode>(MU<FloatResultNode>(30.125))).
+ aggregate(DocId(3), HitRank(40));
+
+ EXPECT_EQUAL(3u, aggr.getCount());
+ EXPECT_APPROX(131.875, aggr.getSum(), 0.01);
+ EXPECT_APPROX(10959.8, aggr.getSumOfSquared(), 0.1);
+ EXPECT_APPROX(41.5, aggr.getRank().getFloat(), 0.1);
+}
+
+TEST("require that StandardDeviationAggregationResult aggregates multi-value expression correctly") {
+ StandardDeviationAggregationResult aggr;
+ aggr.setExpression(createVectorFloat(std::vector<double>({1.5, 100.25, 30.125}))).
+ aggregate(DocId(42), HitRank(21));
+
+ EXPECT_EQUAL(3u, aggr.getCount());
+ EXPECT_APPROX(131.875, aggr.getSum(), 0.01);
+ EXPECT_APPROX(10959.8, aggr.getSumOfSquared(), 0.1);
+ EXPECT_APPROX(41.5, aggr.getRank().getFloat(), 0.1);
+}
+
void testAdd(const ResultNode &a, const ResultNode &b, const ResultNode &c) {
AddFunctionNode func;
func.appendArg(MU<ConstantNode>(ResultNode::UP(a.clone())))
@@ -1006,22 +1095,7 @@ void testModulo(ExpressionNode::UP arg1, ExpressionNode::UP arg2,
testArith(add, std::move(arg1), std::move(arg2), intResult, floatResult);
}
-ExpressionNode::UP
-createVectorInt(const std::vector<double> & v) {
- std::unique_ptr<IntegerResultNodeVector> r = MU<IntegerResultNodeVector>();
- for (double d : v) {
- r->push_back(Int64ResultNode(static_cast<int64_t>(d)));
- }
- return MU<ConstantNode>(std::move(r));
-}
-ExpressionNode::UP
-createVectorFloat(const std::vector<double> & v) {
- std::unique_ptr<FloatResultNodeVector> r = MU<FloatResultNodeVector>();
- for (double d : v) {
- r->push_back(FloatResultNode(d));
- }
- return MU<ConstantNode>(std::move(r));
-}
+
void testArithmeticArguments(NumericFunctionNode &function,
const std::vector<double> & arg1,
@@ -1516,6 +1590,7 @@ TEST("testStreamingAll") {
testStreaming(RawResultNode("Tester RawResultNode streaming", 30));
testStreaming(CountAggregationResult());
testStreaming(ExpressionCountAggregationResult());
+ testStreaming(StandardDeviationAggregationResult());
testStreaming(SumAggregationResult());
testStreaming(MinAggregationResult());
testStreaming(MaxAggregationResult());
diff --git a/searchlib/src/tests/grouping/grouping_serialization_test.cpp b/searchlib/src/tests/grouping/grouping_serialization_test.cpp
index 8592995915c..734c9095d48 100644
--- a/searchlib/src/tests/grouping/grouping_serialization_test.cpp
+++ b/searchlib/src/tests/grouping/grouping_serialization_test.cpp
@@ -229,6 +229,10 @@ TEST_F("testAggregatorResults", Fixture("testAggregatorResults")) {
expression_count.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(67)))
.aggregate(DocId(42), HitRank(21));
f.checkObject(expression_count);
+ StandardDeviationAggregationResult stddev;
+ stddev.setExpression(MU<ConstantNode>(MU<Int64ResultNode>(67)))
+ .aggregate(DocId(42), HitRank(21));
+ f.checkObject(stddev);
}
TEST_F("testHitCollection", Fixture("testHitCollection")) {
diff --git a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp
index 1cfc9804390..9ccad83df19 100644
--- a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp
+++ b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp
@@ -37,6 +37,7 @@ IMPLEMENT_AGGREGATIONRESULT(MinAggregationResult, AggregationResult);
IMPLEMENT_AGGREGATIONRESULT(AverageAggregationResult, AggregationResult);
IMPLEMENT_AGGREGATIONRESULT(XorAggregationResult, AggregationResult);
IMPLEMENT_AGGREGATIONRESULT(ExpressionCountAggregationResult, AggregationResult);
+IMPLEMENT_AGGREGATIONRESULT(StandardDeviationAggregationResult, AggregationResult);
AggregationResult::AggregationResult() :
_expressionTree(new ExpressionTree()),
@@ -501,6 +502,82 @@ Deserializer &ExpressionCountAggregationResult::onDeserialize(
ExpressionCountAggregationResult::ExpressionCountAggregationResult() : AggregationResult(), _hll() { }
ExpressionCountAggregationResult::~ExpressionCountAggregationResult() {}
+StandardDeviationAggregationResult::StandardDeviationAggregationResult()
+ : AggregationResult(), _count(), _sum(), _sumOfSquared(), _stdDevScratchPad()
+{
+ _stdDevScratchPad.reset(new expression::FloatResultNode());
+}
+
+StandardDeviationAggregationResult::~StandardDeviationAggregationResult() {}
+
+const NumericResultNode& StandardDeviationAggregationResult::getStandardDeviation() const noexcept
+{
+ if (_count == 0) {
+ _stdDevScratchPad->set(Int64ResultNode(0));
+ } else {
+ double variance = (_sumOfSquared.getFloat() - _sum.getFloat() * _sum.getFloat() / _count) / _count;
+ double stddev = std::sqrt(variance);
+ _stdDevScratchPad->set(FloatResultNode(stddev));
+ }
+ return *_stdDevScratchPad;
+}
+
+void StandardDeviationAggregationResult::onMerge(const AggregationResult &r) {
+ const StandardDeviationAggregationResult &result =
+ Identifiable::cast<const StandardDeviationAggregationResult &>(r);
+ _count += result._count;
+ _sum.add(result._sum);
+ _sumOfSquared.add(result._sumOfSquared);
+}
+
+void StandardDeviationAggregationResult::onAggregate(const ResultNode &result) {
+ if (result.isMultiValue()) {
+ static_cast<const ResultNodeVector &>(result).flattenSum(_sum);
+ static_cast<const ResultNodeVector &>(result).flattenSumOfSquared(_sumOfSquared);
+ _count += static_cast<const ResultNodeVector &>(result).size();
+ } else {
+ _sum.add(result);
+ FloatResultNode squared(result.getFloat());
+ squared.multiply(result);
+ _sumOfSquared.add(squared);
+ _count++;
+ }
+}
+
+void StandardDeviationAggregationResult::onReset()
+{
+ _count = 0;
+ _sum.set(0.0);
+ _sumOfSquared.set(0.0);
+}
+
+Serializer & StandardDeviationAggregationResult::onSerialize(Serializer & os) const
+{
+ AggregationResult::onSerialize(os);
+ double sum = _sum.getFloat();
+ double sumOfSquared = _sumOfSquared.getFloat();
+ return os << _count << sum << sumOfSquared;
+}
+
+Deserializer & StandardDeviationAggregationResult::onDeserialize(Deserializer & is)
+{
+ AggregationResult::onDeserialize(is);
+ double sum;
+ double sumOfSquared;
+ auto& r = is >> _count >> sum >> sumOfSquared;
+ _sum.set(sum);
+ _sumOfSquared.set(sumOfSquared);
+ return r;
+}
+
+void StandardDeviationAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const
+{
+ AggregationResult::visitMembers(visitor);
+ visit(visitor, "count", _count);
+ visit(visitor, "sum", _sum);
+ visit(visitor, "sumOfSquared", _sumOfSquared);
+}
+
} // namespace aggregation
} // namespace search
diff --git a/searchlib/src/vespa/searchlib/aggregation/aggregation.h b/searchlib/src/vespa/searchlib/aggregation/aggregation.h
index e17ee88c113..ad5d48b35b8 100644
--- a/searchlib/src/vespa/searchlib/aggregation/aggregation.h
+++ b/searchlib/src/vespa/searchlib/aggregation/aggregation.h
@@ -11,6 +11,7 @@
#include <vespa/searchlib/aggregation/averageaggregationresult.h>
#include <vespa/searchlib/aggregation/xoraggregationresult.h>
#include <vespa/searchlib/aggregation/hitsaggregationresult.h>
+#include <vespa/searchlib/aggregation/standarddeviationaggregationresult.h>
#include <vespa/searchlib/aggregation/grouping.h>
namespace search {
diff --git a/searchlib/src/vespa/searchlib/aggregation/standarddeviationaggregationresult.h b/searchlib/src/vespa/searchlib/aggregation/standarddeviationaggregationresult.h
new file mode 100644
index 00000000000..283637d5f0b
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/aggregation/standarddeviationaggregationresult.h
@@ -0,0 +1,35 @@
+// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#pragma once
+
+#include "aggregationresult.h"
+#include <vespa/searchlib/expression/floatresultnode.h>
+#include <vespa/searchlib/expression/integerresultnode.h>
+
+
+namespace search::aggregation {
+
+// Aggregator that calculates the population standard deviation
+class StandardDeviationAggregationResult : public AggregationResult
+{
+public:
+ DECLARE_AGGREGATIONRESULT(StandardDeviationAggregationResult);
+ StandardDeviationAggregationResult();
+ ~StandardDeviationAggregationResult();
+
+ void visitMembers(vespalib::ObjectVisitor &visitor) const override;
+ double getSum() const noexcept { return _sum.getFloat(); }
+ double getSumOfSquared() const noexcept { return _sumOfSquared.getFloat(); }
+ uint64_t getCount() const noexcept { return _count; }
+private:
+ const ResultNode& onGetRank() const noexcept override { return getStandardDeviation(); }
+ void onPrepare(const ResultNode&, bool) override { };
+ const expression::NumericResultNode& getStandardDeviation() const noexcept;
+
+ uint64_t _count;
+ expression::FloatResultNode _sum;
+ expression::FloatResultNode _sumOfSquared;
+ mutable expression::FloatResultNode::CP _stdDevScratchPad;
+};
+
+}
+
diff --git a/searchlib/src/vespa/searchlib/common/identifiable.h b/searchlib/src/vespa/searchlib/common/identifiable.h
index a2a2dfdb7bb..5ee131ccd56 100644
--- a/searchlib/src/vespa/searchlib/common/identifiable.h
+++ b/searchlib/src/vespa/searchlib/common/identifiable.h
@@ -90,6 +90,8 @@
#define CID_search_aggregation_HitsAggregationResult SEARCHLIB_CID(87)
#define CID_search_aggregation_ExpressionCountAggregationResult \
SEARCHLIB_CID(88)
+#define CID_search_aggregation_StandardDeviationAggregationResult \
+ SEARCHLIB_CID(89)
#define CID_search_aggregation_Group SEARCHLIB_CID(90)
#define CID_search_aggregation_Grouping SEARCHLIB_CID(91)
diff --git a/searchlib/src/vespa/searchlib/expression/resultvector.h b/searchlib/src/vespa/searchlib/expression/resultvector.h
index 9d6b409449e..f91354ce771 100644
--- a/searchlib/src/vespa/searchlib/expression/resultvector.h
+++ b/searchlib/src/vespa/searchlib/expression/resultvector.h
@@ -31,6 +31,7 @@ public:
virtual ResultNode & get(size_t index) = 0;
virtual void clear() = 0;
virtual void resize(size_t sz) = 0;
+ virtual void reserve(size_t sz) = 0;
size_t size() const { return onSize(); }
bool empty() const { return size() == 0; }
/**
@@ -44,6 +45,7 @@ public:
virtual ResultNode & flattenAnd(ResultNode & r) const { return r; }
virtual ResultNode & flattenOr(ResultNode & r) const { return r; }
virtual ResultNode & flattenXor(ResultNode & r) const { return r; }
+ virtual ResultNode & flattenSumOfSquared(ResultNode & r) const { return r; }
virtual void min(const ResultNode & b) { (void) b; }
virtual void max(const ResultNode & b) { (void) b; }
virtual void add(const ResultNode & b) { (void) b; }
@@ -92,6 +94,7 @@ public:
ResultNode & get(size_t index) override { return _result[index]; }
void clear() override { _result.clear(); }
void resize(size_t sz) override { _result.resize(sz); }
+ void reserve(size_t sz) override { _result.reserve(sz); }
void negate() override;
private:
void visitMembers(vespalib::ObjectVisitor &visitor) const override { visit(visitor, "Vector", _result); }
@@ -219,7 +222,7 @@ public:
B v;
v.set(r);
const std::vector<B> & vec(this->getVector());
- for(size_t i(0), m(vec.size()); i < m; i++) {
+ for (size_t i(0), m(vec.size()); i < m; i++) {
v.multiply(vec[i]);
}
r.set(v);
@@ -229,7 +232,7 @@ public:
Int64ResultNode v;
v.set(r);
const std::vector<B> & vec(this->getVector());
- for(size_t i(0), m(vec.size()); i < m; i++) {
+ for (size_t i(0), m(vec.size()); i < m; i++) {
v.andOp(vec[i]);
}
r.set(v);
@@ -239,7 +242,7 @@ public:
Int64ResultNode v;
v.set(r);
const std::vector<B> & vec(this->getVector());
- for(size_t i(0), m(vec.size()); i < m; i++) {
+ for (size_t i(0), m(vec.size()); i < m; i++) {
v.orOp(vec[i]);
}
r.set(v);
@@ -249,7 +252,7 @@ public:
Int64ResultNode v;
v.set(r);
const std::vector<B> & vec(this->getVector());
- for(size_t i(0), m(vec.size()); i < m; i++) {
+ for (size_t i(0), m(vec.size()); i < m; i++) {
v.xorOp(vec[i]);
}
r.set(v);
@@ -259,7 +262,7 @@ public:
B v;
v.set(r);
const std::vector<B> & vec(this->getVector());
- for(size_t i(0), m(vec.size()); i < m; i++) {
+ for (size_t i(0), m(vec.size()); i < m; i++) {
v.add(vec[i]);
}
r.set(v);
@@ -269,7 +272,7 @@ public:
B v;
v.set(r);
const std::vector<B> & vec(this->getVector());
- for(size_t i(0), m(vec.size()); i < m; i++) {
+ for (size_t i(0), m(vec.size()); i < m; i++) {
v.max(vec[i]);
}
r.set(v);
@@ -279,12 +282,25 @@ public:
B v;
v.set(r);
const std::vector<B> & vec(this->getVector());
- for(size_t i(0), m(vec.size()); i < m; i++) {
+ for (size_t i(0), m(vec.size()); i < m; i++) {
v.min(vec[i]);
}
r.set(v);
return r;
}
+ ResultNode & flattenSumOfSquared(ResultNode & r) const override {
+ B v;
+ v.set(r);
+ const std::vector<B> & vec(this->getVector());
+ for (size_t i(0), m(vec.size()); i < m; i++) {
+ B squared;
+ squared.set(vec[i]);
+ squared.multiply(vec[i]);
+ v.add(squared);
+ }
+ r.set(v);
+ return r;
+ }
};
@@ -400,6 +416,7 @@ public:
ResultNode & get(size_t index) override { return *_v[index]; }
void clear() override { _v.clear(); }
void resize(size_t sz) override { _v.resize(sz); }
+ void reserve(size_t sz) override { _v.reserve(sz); }
private:
int64_t onGetInteger(size_t index) const override { return _v[index]->getInteger(index); }
double onGetFloat(size_t index) const override { return _v[index]->getFloat(index); }