summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2020-11-02 11:20:14 +0100
committerJon Bratseth <bratseth@gmail.com>2020-11-02 11:20:14 +0100
commit432d35c0d4cc761c6739e63de5dbb6197a369a3d (patch)
treec2a2b49354bebcf1499b6d08161b1c026c73deee /vespajlib
parentac8b4ebae4796b275ff71cc15eb259a22797a913 (diff)
Add median aggregator
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java92
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java35
4 files changed, 114 insertions, 34 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 7e7e376a8df..c6727aa372e 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1097,6 +1097,7 @@
"abstract"
],
"methods": [
+ "public static com.yahoo.tensor.Tensor$Builder of(java.lang.String)",
"public static com.yahoo.tensor.Tensor$Builder of(com.yahoo.tensor.TensorType)",
"public static com.yahoo.tensor.Tensor$Builder of(com.yahoo.tensor.TensorType, com.yahoo.tensor.DimensionSizes)",
"public abstract com.yahoo.tensor.TensorType type()",
@@ -1202,6 +1203,9 @@
"public com.yahoo.tensor.Tensor max()",
"public com.yahoo.tensor.Tensor max(java.lang.String)",
"public com.yahoo.tensor.Tensor max(java.util.List)",
+ "public com.yahoo.tensor.Tensor median()",
+ "public com.yahoo.tensor.Tensor median(java.lang.String)",
+ "public com.yahoo.tensor.Tensor median(java.util.List)",
"public com.yahoo.tensor.Tensor min()",
"public com.yahoo.tensor.Tensor min(java.lang.String)",
"public com.yahoo.tensor.Tensor min(java.util.List)",
@@ -1827,10 +1831,11 @@
"fields": [
"public static final enum com.yahoo.tensor.functions.Reduce$Aggregator avg",
"public static final enum com.yahoo.tensor.functions.Reduce$Aggregator count",
- "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator prod",
- "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator sum",
"public static final enum com.yahoo.tensor.functions.Reduce$Aggregator max",
- "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator min"
+ "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator median",
+ "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator min",
+ "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator prod",
+ "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator sum"
]
},
"com.yahoo.tensor.functions.Reduce": {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 68b5cf8a946..0fba2ca4875 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -92,7 +92,7 @@ public interface Tensor {
*/
default double asDouble() {
if (type().dimensions().size() > 0)
- throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + type().dimensions().size());
+ throw new IllegalStateException("Require a dimensionless tensor but has " + type());
if (size() == 0) return Double.NaN;
return valueIterator().next();
}
@@ -242,6 +242,9 @@ public interface Tensor {
default Tensor max() { return max(Collections.emptyList()); }
default Tensor max(String dimension) { return max(Collections.singletonList(dimension)); }
default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); }
+ default Tensor median() { return median(Collections.emptyList()); }
+ default Tensor median(String dimension) { return median(Collections.singletonList(dimension)); }
+ default Tensor median(List<String> dimensions) { return reduce(Reduce.Aggregator.median, dimensions); }
default Tensor min() { return min(Collections.emptyList()); }
default Tensor min(String dimension) { return min(Collections.singletonList(dimension)); }
default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); }
@@ -469,6 +472,11 @@ public interface Tensor {
interface Builder {
+ /** Creates a suitable builder for the given type spec */
+ static Builder of(String typeSpec) {
+ return of(TensorType.fromSpec(typeSpec));
+ }
+
/** Creates a suitable builder for the given type */
static Builder of(TensorType type) {
boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 1eb09a603fa..48604df87e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@@ -27,7 +28,7 @@ import java.util.Set;
*/
public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
- public enum Aggregator { avg, count, prod, sum, max, min; }
+ public enum Aggregator { avg, count, max, median, min, prod, sum ; }
private final TensorFunction<NAMETYPE> argument;
private final List<String> dimensions;
@@ -53,11 +54,8 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
* @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
*/
public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, List<String> dimensions) {
- Objects.requireNonNull(argument, "The argument tensor cannot be null");
- Objects.requireNonNull(aggregator, "The aggregator cannot be null");
- Objects.requireNonNull(dimensions, "The dimensions cannot be null");
- this.argument = argument;
- this.aggregator = aggregator;
+ this.argument = Objects.requireNonNull(argument, "The argument tensor cannot be null");
+ this.aggregator = Objects.requireNonNull(aggregator, "The aggregator cannot be null");
this.dimensions = ImmutableList.copyOf(dimensions);
}
@@ -186,10 +184,11 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
switch (aggregator) {
case avg : return new AvgAggregator();
case count : return new CountAggregator();
- case prod : return new ProdAggregator();
- case sum : return new SumAggregator();
case max : return new MaxAggregator();
+ case median : return new MedianAggregator();
case min : return new MinAggregator();
+ case prod : return new ProdAggregator();
+ case sum : return new SumAggregator();
default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
}
@@ -249,87 +248,120 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- private static class ProdAggregator extends ValueAggregator {
+ private static class MaxAggregator extends ValueAggregator {
- private double valueProd = 1.0;
+ private double maxValue = Double.MIN_VALUE;
@Override
public void aggregate(double value) {
- valueProd *= value;
+ if (value > maxValue)
+ maxValue = value;
}
@Override
public double aggregatedValue() {
- return valueProd;
+ return maxValue;
}
@Override
public void reset() {
- valueProd = 1.0;
+ maxValue = Double.MIN_VALUE;
}
}
- private static class SumAggregator extends ValueAggregator {
+ private static class MedianAggregator extends ValueAggregator {
- private double valueSum = 0.0;
+ /** If any NaN is added, the result should be NaN */
+ private boolean isNaN = false;
+
+ private List<Double> values = new ArrayList<>();
@Override
public void aggregate(double value) {
- valueSum += value;
+ if ( Double.isNaN(value))
+ isNaN = true;
+ if ( ! isNaN)
+ values.add(value);
}
@Override
public double aggregatedValue() {
- return valueSum;
+ if (isNaN || values.isEmpty()) return Double.NaN;
+ Collections.sort(values);
+ if (values.size() % 2 == 0) // even: average the two middle values
+ return ( values.get(values.size() / 2 - 1) + values.get(values.size() / 2) ) / 2;
+ else
+ return values.get((values.size() - 1)/ 2);
}
@Override
public void reset() {
- valueSum = 0.0;
+ isNaN = false;
+ values = new ArrayList<>();
}
+
}
- private static class MaxAggregator extends ValueAggregator {
+ private static class MinAggregator extends ValueAggregator {
- private double maxValue = Double.MIN_VALUE;
+ private double minValue = Double.MAX_VALUE;
@Override
public void aggregate(double value) {
- if (value > maxValue)
- maxValue = value;
+ if (value < minValue)
+ minValue = value;
}
@Override
public double aggregatedValue() {
- return maxValue;
+ return minValue;
}
@Override
public void reset() {
- maxValue = Double.MIN_VALUE;
+ minValue = Double.MAX_VALUE;
}
+
}
- private static class MinAggregator extends ValueAggregator {
+ private static class ProdAggregator extends ValueAggregator {
- private double minValue = Double.MAX_VALUE;
+ private double valueProd = 1.0;
@Override
public void aggregate(double value) {
- if (value < minValue)
- minValue = value;
+ valueProd *= value;
}
@Override
public double aggregatedValue() {
- return minValue;
+ return valueProd;
}
@Override
public void reset() {
- minValue = Double.MAX_VALUE;
+ valueProd = 1.0;
}
+ }
+
+ private static class SumAggregator extends ValueAggregator {
+ private double valueSum = 0.0;
+
+ @Override
+ public void aggregate(double value) {
+ valueSum += value;
+ }
+
+ @Override
+ public double aggregatedValue() {
+ return valueSum;
+ }
+
+ @Override
+ public void reset() {
+ valueSum = 0.0;
+ }
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java
new file mode 100644
index 00000000000..21fed1745b9
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java
@@ -0,0 +1,35 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author bratseth
+ */
+public class ReduceTestCase {
+
+ private static final double delta = 0.00000001;
+
+ @Test
+ public void testReduce() {
+ assertNan(Tensor.from("tensor(x{})", "{}").median());
+ assertEquals(1.0, Tensor.from("tensor(x[1])", "[1]").median().asDouble(), delta);
+ assertEquals(1.5, Tensor.from("tensor(x[2])", "[1, 2]").median().asDouble(), delta);
+ assertEquals(3.0, Tensor.from("tensor(x[7])", "[3, 1, 1, 1, 4, 4, 4]").median().asDouble(), delta);
+ assertEquals(2.0, Tensor.from("tensor(x[6])", "[3, 1, 1, 1, 4, 4]").median().asDouble(), delta);
+ assertEquals(2.0, Tensor.from("tensor(x{})", "{{x: foo}: 3, {x:bar}: 1}").median().asDouble(), delta);
+
+ assertNan(Tensor.Builder.of("tensor(x[3])").cell(Double.NaN, 0).cell(1, 1).cell(2, 2).build().median());
+ assertNan(Tensor.Builder.of("tensor(x[3])").cell(Double.NaN, 2).cell(1, 1).cell(2, 0).build().median());
+ assertNan(Tensor.Builder.of("tensor(x[1])").cell(Double.NaN, 0).build().median());
+ }
+
+ private void assertNan(Tensor tensor) {
+ assertTrue(tensor + " is NaN", Double.isNaN(tensor.asDouble()));
+ }
+
+}