summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-03 16:58:14 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-03 16:58:14 -0700
commit3d070f94949b0e2bd6124bb694989a31014731ca (patch)
treedf818318707a8d7001b529e0fd5e520fb83865f2 /vespajlib
parent66e2f98fc71338caaa9b0f72ee9109c209b910c6 (diff)
Convenience method to find smallest/largest cells
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java51
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java43
3 files changed, 98 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 02f54b5790a..1da013de012 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -739,6 +739,11 @@ public abstract class IndexedTensor implements Tensor {
@Override
public Double getValue() { return value; }
+ @Override
+ public Cell detach() {
+ return new Cell(getKey(), value);
+ }
+
}
// TODO: Make dimensionSizes a class
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 806d27ce70e..8b0aaa64551 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -63,7 +63,11 @@ public interface Tensor {
/** Returns the value of a cell, or NaN if this cell does not exist/have no value */
double get(TensorAddress address);
- /** Returns the cell of this in some undefined order */
+ /**
+ * Returns the cell of this in some undefined order.
+ * A cell instances is only valid until next() is called.
+ * Call detach() on the cell to obtain a long-lived instance.
+ */
Iterator<Cell> cellIterator();
/** Returns the values of this in some undefined order */
@@ -250,6 +254,44 @@ public interface Tensor {
default Tensor sum(String dimension) { return sum(Collections.singletonList(dimension)); }
default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); }
+ // ----------------- non-math query methods (that is, computations not returning a tensor)
+
+ /** Returns the cell(s) of this tensor having the highest value */
+ default List<Cell> largest() {
+ List<Cell> cells = new ArrayList<>(1);
+ double maxValue = Double.MIN_VALUE;
+ for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) {
+ Cell cell = i.next();
+ if (cell.getValue() > maxValue) {
+ cells.clear();
+ cells.add(cell.detach());
+ maxValue = cell.getDoubleValue();
+ }
+ else if (cell.getValue() == maxValue) {
+ cells.add(cell.detach());
+ }
+ }
+ return cells;
+ }
+
+ /** Returns the cell(s) of this tensor having the lowest value */
+ default List<Cell> smallest() {
+ List<Cell> cells = new ArrayList<>(1);
+ double minValue = Double.MAX_VALUE;
+ for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) {
+ Cell cell = i.next();
+ if (cell.getValue() < minValue) {
+ cells.clear();
+ cells.add(cell.detach());
+ minValue = cell.getDoubleValue();
+ }
+ else if (cell.getValue() == minValue) {
+ cells.add(cell.detach());
+ }
+ }
+ return cells;
+ }
+
// ----------------- serialization
/**
@@ -422,6 +464,13 @@ public interface Tensor {
return getKey().hashCode() ^ getValue().hashCode(); // by Map.Entry spec
}
+ public String toString(TensorType type) { return address.toString(type) + ":" + value; }
+
+ /**
+ * Return a copy of this tensor cell which is valid beyond the lifetime of any iterator state which supplied it.
+ */
+ public Cell detach() { return this; }
+
}
interface Builder {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 3d5d8d1f5ae..c6fbb9c009d 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -15,6 +15,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;
+import java.util.stream.Collectors;
import static com.yahoo.tensor.TensorType.Dimension.Type;
import static org.junit.Assert.assertEquals;
@@ -248,6 +249,48 @@ public class TensorTestCase {
Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"));
}
+ @Test
+ public void testLargest() {
+ assertLargest("{d1:l1,d2:l2}:6.0",
+ "tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}");
+ assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0",
+ "tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l2}:6.0}");
+ assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0",
+ "tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l3}:5.0,{d1:l1,d2:l2}:6.0}");
+ assertLargest("{x:1,y:1}:4.0",
+ "tensor(x[2],y[2]):[[1,2],[3,4]");
+ assertLargest("{x:0,y:0}:4.0, {x:1,y:1}:4.0",
+ "tensor(x[2],y[2]):[[4,2],[3,4]");
+ }
+
+ @Test
+ public void testSmallest() {
+ assertSmallest("{d1:l1,d2:l1}:5.0",
+ "tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}");
+ assertSmallest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0",
+ "tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l2}:6.0}");
+ assertSmallest("{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:5.0",
+ "tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l3}:6.0,{d1:l1,d2:l2}:5.0}");
+ assertSmallest("{x:0,y:0}:1.0",
+ "tensor(x[2],y[2]):[[1,2],[3,4]");
+ assertSmallest("{x:0,y:1}:2.0",
+ "tensor(x[2],y[2]):[[4,2],[3,4]");
+ }
+
+ private void assertLargest(String expectedCells, String tensorString) {
+ Tensor tensor = Tensor.from(tensorString);
+ assertEquals(expectedCells, asString(tensor.largest(), tensor.type()));
+ }
+
+ private void assertSmallest(String expectedCells, String tensorString) {
+ Tensor tensor = Tensor.from(tensorString);
+ assertEquals(expectedCells, asString(tensor.smallest(), tensor.type()));
+ }
+
+ private String asString(List<Tensor.Cell> cells, TensorType type) {
+ return cells.stream().map(cell -> cell.toString(type)).collect(Collectors.joining(", "));
+ }
+
private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) {
assertEquals(expected, init.modify(op, update.cells()));
}