diff options
Diffstat (limited to 'vespajlib')
4 files changed, 114 insertions, 34 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 7e7e376a8df..c6727aa372e 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1097,6 +1097,7 @@ "abstract" ], "methods": [ + "public static com.yahoo.tensor.Tensor$Builder of(java.lang.String)", "public static com.yahoo.tensor.Tensor$Builder of(com.yahoo.tensor.TensorType)", "public static com.yahoo.tensor.Tensor$Builder of(com.yahoo.tensor.TensorType, com.yahoo.tensor.DimensionSizes)", "public abstract com.yahoo.tensor.TensorType type()", @@ -1202,6 +1203,9 @@ "public com.yahoo.tensor.Tensor max()", "public com.yahoo.tensor.Tensor max(java.lang.String)", "public com.yahoo.tensor.Tensor max(java.util.List)", + "public com.yahoo.tensor.Tensor median()", + "public com.yahoo.tensor.Tensor median(java.lang.String)", + "public com.yahoo.tensor.Tensor median(java.util.List)", "public com.yahoo.tensor.Tensor min()", "public com.yahoo.tensor.Tensor min(java.lang.String)", "public com.yahoo.tensor.Tensor min(java.util.List)", @@ -1827,10 +1831,11 @@ "fields": [ "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator avg", "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator count", - "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator prod", - "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator sum", "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator max", - "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator min" + "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator median", + "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator min", + "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator prod", + "public static final enum com.yahoo.tensor.functions.Reduce$Aggregator sum" ] }, "com.yahoo.tensor.functions.Reduce": { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 68b5cf8a946..0fba2ca4875 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -92,7 +92,7 @@ public interface Tensor { */ default double asDouble() { if (type().dimensions().size() > 0) - throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + type().dimensions().size()); + throw new IllegalStateException("Require a dimensionless tensor but has " + type()); if (size() == 0) return Double.NaN; return valueIterator().next(); } @@ -242,6 +242,9 @@ public interface Tensor { default Tensor max() { return max(Collections.emptyList()); } default Tensor max(String dimension) { return max(Collections.singletonList(dimension)); } default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); } + default Tensor median() { return median(Collections.emptyList()); } + default Tensor median(String dimension) { return median(Collections.singletonList(dimension)); } + default Tensor median(List<String> dimensions) { return reduce(Reduce.Aggregator.median, dimensions); } default Tensor min() { return min(Collections.emptyList()); } default Tensor min(String dimension) { return min(Collections.singletonList(dimension)); } default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); } @@ -469,6 +472,11 @@ public interface Tensor { interface Builder { + /** Creates a suitable builder for the given type spec */ + static Builder of(String typeSpec) { + return of(TensorType.fromSpec(typeSpec)); + } + /** Creates a suitable builder for the given type */ static Builder of(TensorType type) { boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 1eb09a603fa..48604df87e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -27,7 +28,7 @@ import java.util.Set; */ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { - public enum Aggregator { avg, count, prod, sum, max, min; } + public enum Aggregator { avg, count, max, median, min, prod, sum ; } private final TensorFunction<NAMETYPE> argument; private final List<String> dimensions; @@ -53,11 +54,8 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor */ public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, List<String> dimensions) { - Objects.requireNonNull(argument, "The argument tensor cannot be null"); - Objects.requireNonNull(aggregator, "The aggregator cannot be null"); - Objects.requireNonNull(dimensions, "The dimensions cannot be null"); - this.argument = argument; - this.aggregator = aggregator; + this.argument = Objects.requireNonNull(argument, "The argument tensor cannot be null"); + this.aggregator = Objects.requireNonNull(aggregator, "The aggregator cannot be null"); this.dimensions = ImmutableList.copyOf(dimensions); } @@ -186,10 +184,11 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET switch (aggregator) { case avg : return new AvgAggregator(); case count : return new CountAggregator(); - case prod : return new ProdAggregator(); - case sum : return new SumAggregator(); case max : return new MaxAggregator(); + case median : return new MedianAggregator(); case min : return new MinAggregator(); + case prod : return new ProdAggregator(); + case sum : return new SumAggregator(); default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); } @@ -249,87 +248,120 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } } - private static class ProdAggregator extends ValueAggregator { + private static class MaxAggregator extends ValueAggregator { - private double valueProd = 1.0; + private double maxValue = Double.MIN_VALUE; @Override public void aggregate(double value) { - valueProd *= value; + if (value > maxValue) + maxValue = value; } @Override public double aggregatedValue() { - return valueProd; + return maxValue; } @Override public void reset() { - valueProd = 1.0; + maxValue = Double.MIN_VALUE; } } - private static class SumAggregator extends ValueAggregator { + private static class MedianAggregator extends ValueAggregator { - private double valueSum = 0.0; + /** If any NaN is added, the result should be NaN */ + private boolean isNaN = false; + + private List<Double> values = new ArrayList<>(); @Override public void aggregate(double value) { - valueSum += value; + if ( Double.isNaN(value)) + isNaN = true; + if ( ! isNaN) + values.add(value); } @Override public double aggregatedValue() { - return valueSum; + if (isNaN || values.isEmpty()) return Double.NaN; + Collections.sort(values); + if (values.size() % 2 == 0) // even: average the two middle values + return ( values.get(values.size() / 2 - 1) + values.get(values.size() / 2) ) / 2; + else + return values.get((values.size() - 1)/ 2); } @Override public void reset() { - valueSum = 0.0; + isNaN = false; + values = new ArrayList<>(); } + } - private static class MaxAggregator extends ValueAggregator { + private static class MinAggregator extends ValueAggregator { - private double maxValue = Double.MIN_VALUE; + private double minValue = Double.MAX_VALUE; @Override public void aggregate(double value) { - if (value > maxValue) - maxValue = value; + if (value < minValue) + minValue = value; } @Override public double aggregatedValue() { - return maxValue; + return minValue; } @Override public void reset() { - maxValue = Double.MIN_VALUE; + minValue = Double.MAX_VALUE; } + } - private static class MinAggregator extends ValueAggregator { + private static class ProdAggregator extends ValueAggregator { - private double minValue = Double.MAX_VALUE; + private double valueProd = 1.0; @Override public void aggregate(double value) { - if (value < minValue) - minValue = value; + valueProd *= value; } @Override public double aggregatedValue() { - return minValue; + return valueProd; } @Override public void reset() { - minValue = Double.MAX_VALUE; + valueProd = 1.0; } + } + + private static class SumAggregator extends ValueAggregator { + private double valueSum = 0.0; + + @Override + public void aggregate(double value) { + valueSum += value; + } + + @Override + public double aggregatedValue() { + return valueSum; + } + + @Override + public void reset() { + valueSum = 0.0; + } } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java new file mode 100644 index 00000000000..21fed1745b9 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceTestCase.java @@ -0,0 +1,35 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author bratseth + */ +public class ReduceTestCase { + + private static final double delta = 0.00000001; + + @Test + public void testReduce() { + assertNan(Tensor.from("tensor(x{})", "{}").median()); + assertEquals(1.0, Tensor.from("tensor(x[1])", "[1]").median().asDouble(), delta); + assertEquals(1.5, Tensor.from("tensor(x[2])", "[1, 2]").median().asDouble(), delta); + assertEquals(3.0, Tensor.from("tensor(x[7])", "[3, 1, 1, 1, 4, 4, 4]").median().asDouble(), delta); + assertEquals(2.0, Tensor.from("tensor(x[6])", "[3, 1, 1, 1, 4, 4]").median().asDouble(), delta); + assertEquals(2.0, Tensor.from("tensor(x{})", "{{x: foo}: 3, {x:bar}: 1}").median().asDouble(), delta); + + assertNan(Tensor.Builder.of("tensor(x[3])").cell(Double.NaN, 0).cell(1, 1).cell(2, 2).build().median()); + assertNan(Tensor.Builder.of("tensor(x[3])").cell(Double.NaN, 2).cell(1, 1).cell(2, 0).build().median()); + assertNan(Tensor.Builder.of("tensor(x[1])").cell(Double.NaN, 0).build().median()); + } + + private void assertNan(Tensor tensor) { + assertTrue(tensor + " is NaN", Double.isNaN(tensor.asDouble())); + } + +} |