summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahoo-inc.com>2017-05-04 12:34:10 +0000
committerBjørn Christian Seime <bjorncs@yahoo-inc.com>2017-05-08 12:59:40 +0200
commitbd9d968854e1723b8f09f6d9ba320e9b3daf774e (patch)
tree0c11f46a573c37b5bf234c14dcfca0af36845db2 /searchlib/src/main
parentb374b62e10cd3e97a74234117546ba95d29403c1 (diff)
Add aggregator for calculating the population standard deviation
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java90
1 files changed, 90 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java
new file mode 100644
index 00000000000..4edeec7ba77
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/StandardDeviationAggregationResult.java
@@ -0,0 +1,90 @@
+// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.aggregation;
+
+import com.yahoo.searchlib.expression.FloatResultNode;
+import com.yahoo.searchlib.expression.ResultNode;
+import com.yahoo.vespa.objects.Deserializer;
+import com.yahoo.vespa.objects.ObjectVisitor;
+import com.yahoo.vespa.objects.Serializer;
+
+/**
+ * @author bjorncs
+ */
+public class StandardDeviationAggregationResult extends AggregationResult {
+ public static final int classId = registerClass(0x4000 + 89, StandardDeviationAggregationResult.class);
+
+ private long count;
+ private double sum;
+ private double sumOfSquared;
+
+ /**
+ * Constructor used for deserialization. Will be instantiated with a default sketch.
+ */
+ @SuppressWarnings("unused")
+ public StandardDeviationAggregationResult() {
+ this(0, 0.0, 0.0);
+ }
+
+ public StandardDeviationAggregationResult(long count, double sum, double sumOfSquared) {
+ this.count = count;
+ this.sum = sum;
+ this.sumOfSquared = sumOfSquared;
+ }
+
+ public double getStandardDeviation() {
+ if (count == 0) {
+ return 0;
+ } else {
+ double variance = (sumOfSquared - sum * sum / count) / count;
+ return Math.sqrt(variance);
+ }
+ }
+
+ @Override
+ public ResultNode getRank() {
+ return new FloatResultNode(getStandardDeviation());
+ }
+
+ @Override
+ protected void onMerge(AggregationResult obj) {
+ StandardDeviationAggregationResult other = (StandardDeviationAggregationResult) obj;
+ count += other.count;
+ sum += other.sum;
+ sumOfSquared += other.sumOfSquared;
+ }
+
+ @Override
+ protected boolean equalsAggregation(AggregationResult obj) {
+ StandardDeviationAggregationResult other = (StandardDeviationAggregationResult) obj;
+ return count == this.count && sum == other.sum && sumOfSquared == other.sumOfSquared;
+ }
+
+ @Override
+ protected void onSerialize(Serializer buf) {
+ super.onSerialize(buf);
+ buf.putLong(null, count);
+ buf.putDouble(null, sum);
+ buf.putDouble(null, sumOfSquared);
+ }
+
+ @Override
+ protected void onDeserialize(Deserializer buf) {
+ super.onDeserialize(buf);
+ count = buf.getLong(null);
+ sum = buf.getDouble(null);
+ sumOfSquared = buf.getDouble(null);
+ }
+
+ @Override
+ protected int onGetClassId() {
+ return classId;
+ }
+
+ @Override
+ public void visitMembers(ObjectVisitor visitor) {
+ super.visitMembers(visitor);
+ visitor.visit("count", count);
+ visitor.visit("sum", sum);
+ visitor.visit("sumOfSquared", sumOfSquared);
+ }
+}