diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-03 16:58:14 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-03 16:58:14 -0700 |
commit | 3d070f94949b0e2bd6124bb694989a31014731ca (patch) | |
tree | df818318707a8d7001b529e0fd5e520fb83865f2 /vespajlib | |
parent | 66e2f98fc71338caaa9b0f72ee9109c209b910c6 (diff) |
Convenience method to find smallest/largest cells
Diffstat (limited to 'vespajlib')
3 files changed, 98 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 02f54b5790a..1da013de012 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -739,6 +739,11 @@ public abstract class IndexedTensor implements Tensor { @Override public Double getValue() { return value; } + @Override + public Cell detach() { + return new Cell(getKey(), value); + } + } // TODO: Make dimensionSizes a class diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 806d27ce70e..8b0aaa64551 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -63,7 +63,11 @@ public interface Tensor { /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); - /** Returns the cell of this in some undefined order */ + /** + * Returns the cell of this in some undefined order. + * A cell instances is only valid until next() is called. + * Call detach() on the cell to obtain a long-lived instance. + */ Iterator<Cell> cellIterator(); /** Returns the values of this in some undefined order */ @@ -250,6 +254,44 @@ public interface Tensor { default Tensor sum(String dimension) { return sum(Collections.singletonList(dimension)); } default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); } + // ----------------- non-math query methods (that is, computations not returning a tensor) + + /** Returns the cell(s) of this tensor having the highest value */ + default List<Cell> largest() { + List<Cell> cells = new ArrayList<>(1); + double maxValue = Double.MIN_VALUE; + for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) { + Cell cell = i.next(); + if (cell.getValue() > maxValue) { + cells.clear(); + cells.add(cell.detach()); + maxValue = cell.getDoubleValue(); + } + else if (cell.getValue() == maxValue) { + cells.add(cell.detach()); + } + } + return cells; + } + + /** Returns the cell(s) of this tensor having the lowest value */ + default List<Cell> smallest() { + List<Cell> cells = new ArrayList<>(1); + double minValue = Double.MAX_VALUE; + for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) { + Cell cell = i.next(); + if (cell.getValue() < minValue) { + cells.clear(); + cells.add(cell.detach()); + minValue = cell.getDoubleValue(); + } + else if (cell.getValue() == minValue) { + cells.add(cell.detach()); + } + } + return cells; + } + // ----------------- serialization /** @@ -422,6 +464,13 @@ public interface Tensor { return getKey().hashCode() ^ getValue().hashCode(); // by Map.Entry spec } + public String toString(TensorType type) { return address.toString(type) + ":" + value; } + + /** + * Return a copy of this tensor cell which is valid beyond the lifetime of any iterator state which supplied it. + */ + public Cell detach() { return this; } + } interface Builder { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 3d5d8d1f5ae..c6fbb9c009d 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -15,6 +15,7 @@ import java.util.Collections; import java.util.List; import java.util.Set; import java.util.function.DoubleBinaryOperator; +import java.util.stream.Collectors; import static com.yahoo.tensor.TensorType.Dimension.Type; import static org.junit.Assert.assertEquals; @@ -248,6 +249,48 @@ public class TensorTestCase { Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}")); } + @Test + public void testLargest() { + assertLargest("{d1:l1,d2:l2}:6.0", + "tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}"); + assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0", + "tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l2}:6.0}"); + assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0", + "tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l3}:5.0,{d1:l1,d2:l2}:6.0}"); + assertLargest("{x:1,y:1}:4.0", + "tensor(x[2],y[2]):[[1,2],[3,4]"); + assertLargest("{x:0,y:0}:4.0, {x:1,y:1}:4.0", + "tensor(x[2],y[2]):[[4,2],[3,4]"); + } + + @Test + public void testSmallest() { + assertSmallest("{d1:l1,d2:l1}:5.0", + "tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}"); + assertSmallest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0", + "tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l2}:6.0}"); + assertSmallest("{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:5.0", + "tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l3}:6.0,{d1:l1,d2:l2}:5.0}"); + assertSmallest("{x:0,y:0}:1.0", + "tensor(x[2],y[2]):[[1,2],[3,4]"); + assertSmallest("{x:0,y:1}:2.0", + "tensor(x[2],y[2]):[[4,2],[3,4]"); + } + + private void assertLargest(String expectedCells, String tensorString) { + Tensor tensor = Tensor.from(tensorString); + assertEquals(expectedCells, asString(tensor.largest(), tensor.type())); + } + + private void assertSmallest(String expectedCells, String tensorString) { + Tensor tensor = Tensor.from(tensorString); + assertEquals(expectedCells, asString(tensor.smallest(), tensor.type())); + } + + private String asString(List<Tensor.Cell> cells, TensorType type) { + return cells.stream().map(cell -> cell.toString(type)).collect(Collectors.joining(", ")); + } + private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) { assertEquals(expected, init.modify(op, update.cells())); } |