aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2020-11-02 11:20:14 +0100
committerJon Bratseth <bratseth@gmail.com>2020-11-02 11:20:14 +0100
commit432d35c0d4cc761c6739e63de5dbb6197a369a3d (patch)
treec2a2b49354bebcf1499b6d08161b1c026c73deee
parentac8b4ebae4796b275ff71cc15eb259a22797a913 (diff)
Add median aggregator
-rw-r--r--searchlib/abi-spec.json5
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java12
-rw-r--r--vespajlib/abi-spec.json11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java92
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java35
7 files changed, 129 insertions, 43 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index aca88ff864e..88eccb4559f 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1049,10 +1049,11 @@
"public static final int ARGMIN",
"public static final int AVG",
"public static final int COUNT",
- "public static final int PROD",
- "public static final int SUM",
"public static final int MAX",
+ "public static final int MEDIAN",
"public static final int MIN",
+ "public static final int PROD",
+ "public static final int SUM",
"public static final int IDENTIFIER",
"public static final int SINGLE_LINE_COMMENT",
"public static final int DEFAULT",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 5f27bbcbeee..09880b8dfc3 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -144,10 +144,11 @@ TOKEN :
<AVG: "avg" > |
<COUNT: "count"> |
- <PROD: "prod"> |
- <SUM: "sum"> |
<MAX: "max"> |
+ <MEDIAN: "median"> |
<MIN: "min"> |
+ <PROD: "prod"> |
+ <SUM: "sum"> |
<IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)>
}
@@ -630,7 +631,7 @@ Reduce.Aggregator tensorReduceAggregator() :
{
}
{
- ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> )
+ ( <AVG> | <COUNT> | <MAX> | <MEDIAN> | <MIN> | <PROD> | <SUM> )
{ return Reduce.Aggregator.valueOf(token.image); }
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index f1c421f6c22..1bf4dc5698d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -194,14 +194,16 @@ public class EvaluationTestCase {
"reduce(tensor0, avg, x, y)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
tester.assertEvaluates("{ {}:4 }",
"reduce(tensor0, count, x, y)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
- tester.assertEvaluates("{ {}:105 }",
- "reduce(tensor0, prod, x, y)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
- tester.assertEvaluates("{ {}:16 }",
- "reduce(tensor0, sum, x, y)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
tester.assertEvaluates("{ {}:7 }",
"reduce(tensor0, max, x, y)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
+ tester.assertEvaluates("{ {}:4 }",
+ "reduce(tensor0, median, x, y)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
tester.assertEvaluates("{ {}:1 }",
"reduce(tensor0, min, x, y)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
+ tester.assertEvaluates("{ {}:105 }",
+ "reduce(tensor0, prod, x, y)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
+ tester.assertEvaluates("{ {}:16 }",
+ "reduce(tensor0, sum, x, y)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
// -- reduce 2 by specifying no arguments
tester.assertEvaluates("{ {}:4 }",
"reduce(tensor0, avg)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
@@ -223,6 +225,8 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {}:-5 }", "sum(tensor0)", "-5.0");
tester.assertEvaluates("{ {}:12.5 }", "sum(tensor0)", "{ {d1:0}:5.5, {d1:1}:7.0 }");
tester.assertEvaluates("{ {}: 0 }", "sum(tensor0)", "{ {d1:0}:5.0, {d1:1}:7.0, {d1:2}:-12.0}");
+ tester.assertEvaluates("{ {}: 8.0 }", "avg(tensor0)", "{ {d1:0}:5.0, {d1:1}:7.0, {d1:2}:12.0}");
+ tester.assertEvaluates("{ {}: 5.0 }", "median(tensor0)", "{ {d1:0}:5.0, {d1:1}:7.0, {d1:2}:-12.0}");
tester.assertEvaluates("{ {y:0}:4, {y:1}:12.0 }",
"sum(tensor0, x)", "{ {x:0,y:0}:1.0, {x:1,y:0}:3.0, {x:0,y:1}:5.0, {x:1,y:1}:7.0 }");
tester.assertEvaluates("{ {x:0}:6, {x:1}:10.0 }",
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 7e7e376a8df..c6727aa372e 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1097,6 +1097,7 @@
"abstract"
],
"methods": [
+ "public static com.yahoo.tensor.Tensor$Builder of(java.lang.String)",
"public static com.yahoo.tensor.Tensor$Builder of(com.yahoo.tensor.TensorType)",
"public static com.yahoo.tensor.Tensor$Builder of(com.yahoo.tensor.TensorType, com.yahoo.tensor.DimensionSizes)",
"public abstract com.yahoo.tensor.TensorType type()",
@@ -1202,6 +1203,9 @@
"public com.yahoo.tensor.Tensor max()",
"public com.yahoo.tensor.Tensor max(java.lang.String)",
"public com.yahoo.tensor.Tensor max(java.util.List)",
+ "public com.yahoo.tensor.Tensor median()",
+ "public com.yahoo.tensor.Tensor median(java.lang.String)",
+ "public com.yahoo.tensor.Tensor median(java.util.List)",
"public com.yahoo.tensor.Tensor min()",
"public com.yahoo.tensor.Tensor min(java.lang.String)",
"public com.yahoo.tensor.Tensor min(java.util.List)",
@@ -1827,10 +1831,11 @@
"fields": [
"public static final enum com.yahoo.tensor.functions.Reduce$Aggregator avg",
"public static final enum com.yahoo.tensor.functions.Reduce$Aggregator count",
- "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator prod",
- "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator sum",
"public static final enum com.yahoo.tensor.functions.Reduce$Aggregator max",
- "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator min"
+ "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator median",
+ "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator min",
+ "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator prod",
+ "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator sum"
]
},
"com.yahoo.tensor.functions.Reduce": {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 68b5cf8a946..0fba2ca4875 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -92,7 +92,7 @@ public interface Tensor {
*/
default double asDouble() {
if (type().dimensions().size() > 0)
- throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + type().dimensions().size());
+ throw new IllegalStateException("Require a dimensionless tensor but has " + type());
if (size() == 0) return Double.NaN;
return valueIterator().next();
}
@@ -242,6 +242,9 @@ public interface Tensor {
default Tensor max() { return max(Collections.emptyList()); }
default Tensor max(String dimension) { return max(Collections.singletonList(dimension)); }
default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); }
+ default Tensor median() { return median(Collections.emptyList()); }
+ default Tensor median(String dimension) { return median(Collections.singletonList(dimension)); }
+ default Tensor median(List<String> dimensions) { return reduce(Reduce.Aggregator.median, dimensions); }
default Tensor min() { return min(Collections.emptyList()); }
default Tensor min(String dimension) { return min(Collections.singletonList(dimension)); }
default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); }
@@ -469,6 +472,11 @@ public interface Tensor {
interface Builder {
+ /** Creates a suitable builder for the given type spec */
+ static Builder of(String typeSpec) {
+ return of(TensorType.fromSpec(typeSpec));
+ }
+
/** Creates a suitable builder for the given type */
static Builder of(TensorType type) {
boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 1eb09a603fa..48604df87e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@@ -27,7 +28,7 @@ import java.util.Set;
*/
public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
- public enum Aggregator { avg, count, prod, sum, max, min; }
+ public enum Aggregator { avg, count, max, median, min, prod, sum ; }
private final TensorFunction<NAMETYPE> argument;
private final List<String> dimensions;
@@ -53,11 +54,8 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
* @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
*/
public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, List<String> dimensions) {
- Objects.requireNonNull(argument, "The argument tensor cannot be null");
- Objects.requireNonNull(aggregator, "The aggregator cannot be null");
- Objects.requireNonNull(dimensions, "The dimensions cannot be null");
- this.argument = argument;
- this.aggregator = aggregator;
+ this.argument = Objects.requireNonNull(argument, "The argument tensor cannot be null");
+ this.aggregator = Objects.requireNonNull(aggregator, "The aggregator cannot be null");
this.dimensions = ImmutableList.copyOf(dimensions);
}
@@ -186,10 +184,11 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
switch (aggregator) {
case avg : return new AvgAggregator();
case count : return new CountAggregator();
- case prod : return new ProdAggregator();
- case sum : return new SumAggregator();
case max : return new MaxAggregator();
+ case median : return new MedianAggregator();
case min : return new MinAggregator();
+ case prod : return new ProdAggregator();
+ case sum : return new SumAggregator();
default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
}
@@ -249,87 +248,120 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- private static class ProdAggregator extends ValueAggregator {
+ private static class MaxAggregator extends ValueAggregator {
- private double valueProd = 1.0;
+ private double maxValue = Double.MIN_VALUE;
@Override
public void aggregate(double value) {
- valueProd *= value;
+ if (value > maxValue)
+ maxValue = value;
}
@Override
public double aggregatedValue() {
- return valueProd;
+ return maxValue;
}
@Override
public void reset() {
- valueProd = 1.0;
+ maxValue = Double.MIN_VALUE;
}
}
- private static class SumAggregator extends ValueAggregator {
+ private static class MedianAggregator extends ValueAggregator {
- private double valueSum = 0.0;
+ /** If any NaN is added, the result should be NaN */
+ private boolean isNaN = false;
+
+ private List<Double> values = new ArrayList<>();
@Override
public void aggregate(double value) {
- valueSum += value;
+ if ( Double.isNaN(value))
+ isNaN = true;
+ if ( ! isNaN)
+ values.add(value);
}
@Override
public double aggregatedValue() {
- return valueSum;
+ if (isNaN || values.isEmpty()) return Double.NaN;
+ Collections.sort(values);
+ if (values.size() % 2 == 0) // even: average the two middle values
+ return ( values.get(values.size() / 2 - 1) + values.get(values.size() / 2) ) / 2;
+ else
+ return values.get((values.size() - 1)/ 2);
}
@Override
public void reset() {
- valueSum = 0.0;
+ isNaN = false;
+ values = new ArrayList<>();
}
+
}
- private static class MaxAggregator extends ValueAggregator {
+ private static class MinAggregator extends ValueAggregator {
- private double maxValue = Double.MIN_VALUE;
+ private double minValue = Double.MAX_VALUE;
@Override
public void aggregate(double value) {
- if (value > maxValue)
- maxValue = value;
+ if (value < minValue)
+ minValue = value;
}
@Override
public double aggregatedValue() {
- return maxValue;
+ return minValue;
}
@Override
public void reset() {
- maxValue = Double.MIN_VALUE;
+ minValue = Double.MAX_VALUE;
}
+
}
- private static class MinAggregator extends ValueAggregator {
+ private static class ProdAggregator extends ValueAggregator {
- private double minValue = Double.MAX_VALUE;
+ private double valueProd = 1.0;
@Override
public void aggregate(double value) {
- if (value < minValue)
- minValue = value;
+ valueProd *= value;
}
@Override
public double aggregatedValue() {
- return minValue;
+ return valueProd;
}
@Override
public void reset() {
- minValue = Double.MAX_VALUE;
+ valueProd = 1.0;
}
+ }
+
+ private static class SumAggregator extends ValueAggregator {
+ private double valueSum = 0.0;
+
+ @Override
+ public void aggregate(double value) {
+ valueSum += value;
+ }
+
+ @Override
+ public double aggregatedValue() {
+ return valueSum;
+ }
+
+ @Override
+ public void reset() {
+ valueSum = 0.0;
+ }
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java
new file mode 100644
index 00000000000..21fed1745b9
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java
@@ -0,0 +1,35 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author bratseth
+ */
+public class ReduceTestCase {
+
+ private static final double delta = 0.00000001;
+
+ @Test
+ public void testReduce() {
+ assertNan(Tensor.from("tensor(x{})", "{}").median());
+ assertEquals(1.0, Tensor.from("tensor(x[1])", "[1]").median().asDouble(), delta);
+ assertEquals(1.5, Tensor.from("tensor(x[2])", "[1, 2]").median().asDouble(), delta);
+ assertEquals(3.0, Tensor.from("tensor(x[7])", "[3, 1, 1, 1, 4, 4, 4]").median().asDouble(), delta);
+ assertEquals(2.0, Tensor.from("tensor(x[6])", "[3, 1, 1, 1, 4, 4]").median().asDouble(), delta);
+ assertEquals(2.0, Tensor.from("tensor(x{})", "{{x: foo}: 3, {x:bar}: 1}").median().asDouble(), delta);
+
+ assertNan(Tensor.Builder.of("tensor(x[3])").cell(Double.NaN, 0).cell(1, 1).cell(2, 2).build().median());
+ assertNan(Tensor.Builder.of("tensor(x[3])").cell(Double.NaN, 2).cell(1, 1).cell(2, 0).build().median());
+ assertNan(Tensor.Builder.of("tensor(x[1])").cell(Double.NaN, 0).build().median());
+ }
+
+ private void assertNan(Tensor tensor) {
+ assertTrue(tensor + " is NaN", Double.isNaN(tensor.asDouble()));
+ }
+
+}