diff options
7 files changed, 10 insertions, 214 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 788f5841df9..536be62541e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.TensorType; +import java.util.Collections; import java.util.Optional; /** @@ -17,7 +18,7 @@ import java.util.Optional; * * @author bratseth */ - @Beta +@Beta public class TensorValue extends Value { /** The tensor value of this */ @@ -97,11 +98,11 @@ public class TensorValue extends Value { } public Value sum(String dimension) { - return new TensorValue(value.sum(dimension)); + return new TensorValue(value.sum(Collections.singletonList(dimension))); } public Value sum() { - return new DoubleValue(value.sum()); + return new TensorValue(value.sum(Collections.emptyList())); } private Tensor asTensor(Value value, String operationName) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index d752fb1fccd..e442d9823d5 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -149,54 +149,14 @@ public class EvaluationTestCase extends junit.framework.TestCase { "tensor0 - tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:4 }", "tensor0 / tensor1", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }"); + assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:7 }", + "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:3, {x:2,y:1}:5 }", + "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }", "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }"); - // min - assertEvaluates("{ {x:1}:3, {x:2}:5 }", - "min(tensor0, tensor1)", "{ {x:1}:3 }", "{ {x:2}:5 }"); - assertEvaluates("{ {x:1}:3 }", - "min(tensor0, tensor1)", "{ {x:1}:3 }", "{ {x:1}:5 }"); - assertEvaluates("{ {x:1}:3, {y:1}:5 }", - "min(tensor0, tensor1)", "{ {x:1}:3 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", - "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", - "min(tensor0, tensor1)", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "min(tensor0, tensor1)", "{ {x:1}:5, {x:1,y:1}:1 }", "{ {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", - "min(tensor0, tensor1)", "{ {x:1}:5, {x:1,y:1}:1 }", "{ {z:1}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "min(tensor0, tensor1)", "{ {}:5, {x:1,y:1}:1 }", "{ {y:1,z:1}:7 }"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "min(tensor0, tensor1)", "{ {}:5, {x:1,y:1}:1 }", "{ {}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }", - "min(tensor0, tensor1)", "{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }", "{ {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 }"); - - // max - assertEvaluates("{ {x:1}:3, {x:2}:5 }", - "max(tensor0, tensor1)", "{ {x:1}:3 }", "{ {x:2}:5 }"); - assertEvaluates("{ {x:1}:5 }", - "max(tensor0, tensor1)", "{ {x:1}:3 }", "{ {x:1}:5 }"); - assertEvaluates("{ {x:1}:3, {y:1}:5 }", - "max(tensor0, tensor1)", "{ {x:1}:3 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", - "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", - "max(tensor0, tensor1)", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "max(tensor0, tensor1)", "{ {x:1}:5, {x:1,y:1}:1 }", "{ {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", - "max(tensor0, tensor1)", "{ {x:1}:5, {x:1,y:1}:1 }", "{ {z:1}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "max(tensor0, tensor1)", "{ {}:5, {x:1,y:1}:1 }", "{ {y:1,z:1}:7 }"); - assertEvaluates("{ {}:11, {x:1,y:1}:1, {y:1,z:1}:7 }", - "max(tensor0, tensor1)", "{ {}:5, {x:1,y:1}:1 }", "{ {}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {}:5, {x:1}:5, {x:2}:4, {x:1,y:1}:7, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }", - "max(tensor0, tensor1)", "{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }", "{ {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 }"); - // Combined assertEvaluates(String.valueOf(7.5 + 45 + 1.7), "sum( " + // model computation: diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index b65830a4e4a..380ba9177b2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -96,6 +96,8 @@ public interface Tensor { default Tensor add(Tensor argument) { return join(argument, (a, b) -> (a + b )); } default Tensor divide(Tensor argument) { return join(argument, (a, b) -> (a / b )); } default Tensor subtract(Tensor argument) { return join(argument, (a, b) -> (a - b )); } + default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); } + default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); } default Tensor avg(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.avg, dimensions); } default Tensor count(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.count, dimensions); } @@ -107,28 +109,6 @@ public interface Tensor { // ----------------- Old stuff /** - * Returns a tensor which contains the cells of both argument tensors, where the value for - * any <i>matching</i> cell is the min of the two possible values. - * <p> - * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, - * and have the value undefined for any non-shared dimension. - */ - default Tensor min(Tensor argument) { - return new TensorMin(this, argument).result(); - } - - /** - * Returns a tensor which contains the cells of both argument tensors, where the value for - * any <i>matching</i> cell is the max of the two possible values. - * <p> - * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, - * and have the value undefined for any non-shared dimension. - */ - default Tensor max(Tensor argument) { - return new TensorMax(this, argument).result(); - } - - /** * Returns a tensor with the same cells as this and the given function is applied to all its cell values. * * @param function the function to apply to all cells @@ -139,24 +119,6 @@ public interface Tensor { } /** - * Returns a tensor with the given dimension removed and cells which contains the sum of the values - * in the removed dimension. - */ - default Tensor sum(String dimension) { - return new TensorDimensionSum(dimension, this).result(); - } - - /** - * Returns the sum of all the cells of this tensor. - */ - default double sum() { - double sum = 0; - for (Map.Entry<TensorAddress, Double> cell : cells().entrySet()) - sum += cell.getValue(); - return sum; - } - - /** * Returns true if the given tensor is mathematically equal to this: * Both are of type Tensor and have the same content. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java deleted file mode 100644 index ceb003b1615..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Takes the difference between two tensors, see {@link Tensor#subtract} - * - * @author bratseth - */ -class TensorDifference { - - private final Set<String> dimensions; - private final Map<TensorAddress, Double> cells = new HashMap<>(); - - public TensorDifference(Tensor a, Tensor b) { - this.dimensions = TensorOperations.combineDimensions(a, b); - cells.putAll(a.cells()); - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) - cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) - bCell.getValue()); - } - - /** Returns the result of taking this sum */ - public Tensor result() { - return new MapTensor(dimensions, cells); - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java deleted file mode 100644 index d15e5092476..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Takes the max of each cell of two tensors, see {@link Tensor#max} - * - * @author bratseth - */ -class TensorMax { - - private final Set<String> dimensions; - private final Map<TensorAddress, Double> cells = new HashMap<>(); - - public TensorMax(Tensor a, Tensor b) { - dimensions = TensorOperations.combineDimensions(a, b); - cells.putAll(a.cells()); - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { - Double aValue = a.cells().get(bCell.getKey()); - if (aValue == null) - cells.put(bCell.getKey(), bCell.getValue()); - else - cells.put(bCell.getKey(), Math.max(aValue, bCell.getValue())); - } - } - - /** Returns the result of taking this sum */ - public Tensor result() { - return new MapTensor(dimensions, cells); - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java deleted file mode 100644 index e389dea3883..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Takes the min of each cell of two tensors, see {@link Tensor#min} - * - * @author bratseth - */ -class TensorMin { - - private final Set<String> dimensions; - private final Map<TensorAddress, Double> cells = new HashMap<>(); - - public TensorMin(Tensor a, Tensor b) { - dimensions = TensorOperations.combineDimensions(a, b); - cells.putAll(a.cells()); - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { - Double aValue = a.cells().get(bCell.getKey()); - if (aValue == null) - cells.put(bCell.getKey(), bCell.getValue()); - else - cells.put(bCell.getKey(), Math.min(aValue, bCell.getValue())); - } - } - - /** Returns the result of taking this sum */ - public Tensor result() { return new MapTensor(dimensions, cells); } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java deleted file mode 100644 index 85dfa289bd3..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor; - -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Takes the sum of two tensors, see {@link Tensor#add} - * - * @author bratseth - */ -class TensorSum { - - private final Set<String> dimensions; - private final Map<TensorAddress, Double> cells = new HashMap<>(); - - public TensorSum(Tensor a, Tensor b) { - dimensions = TensorOperations.combineDimensions(a, b); - cells.putAll(a.cells()); - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { - cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) + bCell.getValue()); - } - } - - /** Returns the result of taking this sum */ - public Tensor result() { return new MapTensor(dimensions, cells); } - -} |