aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java48
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java42
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java30
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java29
7 files changed, 10 insertions, 214 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 788f5841df9..536be62541e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.TensorType;
+import java.util.Collections;
import java.util.Optional;
/**
@@ -17,7 +18,7 @@ import java.util.Optional;
*
* @author bratseth
*/
- @Beta
+@Beta
public class TensorValue extends Value {
/** The tensor value of this */
@@ -97,11 +98,11 @@ public class TensorValue extends Value {
}
public Value sum(String dimension) {
- return new TensorValue(value.sum(dimension));
+ return new TensorValue(value.sum(Collections.singletonList(dimension)));
}
public Value sum() {
- return new DoubleValue(value.sum());
+ return new TensorValue(value.sum(Collections.emptyList()));
}
private Tensor asTensor(Value value, String operationName) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index d752fb1fccd..e442d9823d5 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -149,54 +149,14 @@ public class EvaluationTestCase extends junit.framework.TestCase {
"tensor0 - tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:4 }",
"tensor0 / tensor1", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }");
+ assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:7 }",
+ "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:3, {x:2,y:1}:5 }",
+ "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }",
"tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }");
- // min
- assertEvaluates("{ {x:1}:3, {x:2}:5 }",
- "min(tensor0, tensor1)", "{ {x:1}:3 }", "{ {x:2}:5 }");
- assertEvaluates("{ {x:1}:3 }",
- "min(tensor0, tensor1)", "{ {x:1}:3 }", "{ {x:1}:5 }");
- assertEvaluates("{ {x:1}:3, {y:1}:5 }",
- "min(tensor0, tensor1)", "{ {x:1}:3 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
- "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
- "min(tensor0, tensor1)", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "min(tensor0, tensor1)", "{ {x:1}:5, {x:1,y:1}:1 }", "{ {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
- "min(tensor0, tensor1)", "{ {x:1}:5, {x:1,y:1}:1 }", "{ {z:1}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "min(tensor0, tensor1)", "{ {}:5, {x:1,y:1}:1 }", "{ {y:1,z:1}:7 }");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "min(tensor0, tensor1)", "{ {}:5, {x:1,y:1}:1 }", "{ {}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }",
- "min(tensor0, tensor1)", "{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }", "{ {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 }");
-
- // max
- assertEvaluates("{ {x:1}:3, {x:2}:5 }",
- "max(tensor0, tensor1)", "{ {x:1}:3 }", "{ {x:2}:5 }");
- assertEvaluates("{ {x:1}:5 }",
- "max(tensor0, tensor1)", "{ {x:1}:3 }", "{ {x:1}:5 }");
- assertEvaluates("{ {x:1}:3, {y:1}:5 }",
- "max(tensor0, tensor1)", "{ {x:1}:3 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
- "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
- "max(tensor0, tensor1)", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "max(tensor0, tensor1)", "{ {x:1}:5, {x:1,y:1}:1 }", "{ {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
- "max(tensor0, tensor1)", "{ {x:1}:5, {x:1,y:1}:1 }", "{ {z:1}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "max(tensor0, tensor1)", "{ {}:5, {x:1,y:1}:1 }", "{ {y:1,z:1}:7 }");
- assertEvaluates("{ {}:11, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "max(tensor0, tensor1)", "{ {}:5, {x:1,y:1}:1 }", "{ {}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {}:5, {x:1}:5, {x:2}:4, {x:1,y:1}:7, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }",
- "max(tensor0, tensor1)", "{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }", "{ {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 }");
-
// Combined
assertEvaluates(String.valueOf(7.5 + 45 + 1.7),
"sum( " + // model computation:
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index b65830a4e4a..380ba9177b2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -96,6 +96,8 @@ public interface Tensor {
default Tensor add(Tensor argument) { return join(argument, (a, b) -> (a + b )); }
default Tensor divide(Tensor argument) { return join(argument, (a, b) -> (a / b )); }
default Tensor subtract(Tensor argument) { return join(argument, (a, b) -> (a - b )); }
+ default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); }
+ default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); }
default Tensor avg(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.avg, dimensions); }
default Tensor count(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.count, dimensions); }
@@ -107,28 +109,6 @@ public interface Tensor {
// ----------------- Old stuff
/**
- * Returns a tensor which contains the cells of both argument tensors, where the value for
- * any <i>matching</i> cell is the min of the two possible values.
- * <p>
- * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
- * and have the value undefined for any non-shared dimension.
- */
- default Tensor min(Tensor argument) {
- return new TensorMin(this, argument).result();
- }
-
- /**
- * Returns a tensor which contains the cells of both argument tensors, where the value for
- * any <i>matching</i> cell is the max of the two possible values.
- * <p>
- * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
- * and have the value undefined for any non-shared dimension.
- */
- default Tensor max(Tensor argument) {
- return new TensorMax(this, argument).result();
- }
-
- /**
* Returns a tensor with the same cells as this and the given function is applied to all its cell values.
*
* @param function the function to apply to all cells
@@ -139,24 +119,6 @@ public interface Tensor {
}
/**
- * Returns a tensor with the given dimension removed and cells which contains the sum of the values
- * in the removed dimension.
- */
- default Tensor sum(String dimension) {
- return new TensorDimensionSum(dimension, this).result();
- }
-
- /**
- * Returns the sum of all the cells of this tensor.
- */
- default double sum() {
- double sum = 0;
- for (Map.Entry<TensorAddress, Double> cell : cells().entrySet())
- sum += cell.getValue();
- return sum;
- }
-
- /**
* Returns true if the given tensor is mathematically equal to this:
* Both are of type Tensor and have the same content.
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java
deleted file mode 100644
index ceb003b1615..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorDifference.java
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Takes the difference between two tensors, see {@link Tensor#subtract}
- *
- * @author bratseth
- */
-class TensorDifference {
-
- private final Set<String> dimensions;
- private final Map<TensorAddress, Double> cells = new HashMap<>();
-
- public TensorDifference(Tensor a, Tensor b) {
- this.dimensions = TensorOperations.combineDimensions(a, b);
- cells.putAll(a.cells());
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet())
- cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) - bCell.getValue());
- }
-
- /** Returns the result of taking this sum */
- public Tensor result() {
- return new MapTensor(dimensions, cells);
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java
deleted file mode 100644
index d15e5092476..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorMax.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Takes the max of each cell of two tensors, see {@link Tensor#max}
- *
- * @author bratseth
- */
-class TensorMax {
-
- private final Set<String> dimensions;
- private final Map<TensorAddress, Double> cells = new HashMap<>();
-
- public TensorMax(Tensor a, Tensor b) {
- dimensions = TensorOperations.combineDimensions(a, b);
- cells.putAll(a.cells());
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
- Double aValue = a.cells().get(bCell.getKey());
- if (aValue == null)
- cells.put(bCell.getKey(), bCell.getValue());
- else
- cells.put(bCell.getKey(), Math.max(aValue, bCell.getValue()));
- }
- }
-
- /** Returns the result of taking this sum */
- public Tensor result() {
- return new MapTensor(dimensions, cells);
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java
deleted file mode 100644
index e389dea3883..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorMin.java
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Takes the min of each cell of two tensors, see {@link Tensor#min}
- *
- * @author bratseth
- */
-class TensorMin {
-
- private final Set<String> dimensions;
- private final Map<TensorAddress, Double> cells = new HashMap<>();
-
- public TensorMin(Tensor a, Tensor b) {
- dimensions = TensorOperations.combineDimensions(a, b);
- cells.putAll(a.cells());
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
- Double aValue = a.cells().get(bCell.getKey());
- if (aValue == null)
- cells.put(bCell.getKey(), bCell.getValue());
- else
- cells.put(bCell.getKey(), Math.min(aValue, bCell.getValue()));
- }
- }
-
- /** Returns the result of taking this sum */
- public Tensor result() { return new MapTensor(dimensions, cells); }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java
deleted file mode 100644
index 85dfa289bd3..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorSum.java
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Takes the sum of two tensors, see {@link Tensor#add}
- *
- * @author bratseth
- */
-class TensorSum {
-
- private final Set<String> dimensions;
- private final Map<TensorAddress, Double> cells = new HashMap<>();
-
- public TensorSum(Tensor a, Tensor b) {
- dimensions = TensorOperations.combineDimensions(a, b);
- cells.putAll(a.cells());
- for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) {
- cells.put(bCell.getKey(), a.cells().getOrDefault(bCell.getKey(), 0d) + bCell.getValue());
- }
- }
-
- /** Returns the result of taking this sum */
- public Tensor result() { return new MapTensor(dimensions, cells); }
-
-}