diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 5 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 51 |
2 files changed, 55 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 02f54b5790a..1da013de012 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -739,6 +739,11 @@ public abstract class IndexedTensor implements Tensor { @Override public Double getValue() { return value; } + @Override + public Cell detach() { + return new Cell(getKey(), value); + } + } // TODO: Make dimensionSizes a class diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 806d27ce70e..8b0aaa64551 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -63,7 +63,11 @@ public interface Tensor { /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); - /** Returns the cell of this in some undefined order */ + /** + * Returns the cell of this in some undefined order. + * A cell instances is only valid until next() is called. + * Call detach() on the cell to obtain a long-lived instance. + */ Iterator<Cell> cellIterator(); /** Returns the values of this in some undefined order */ @@ -250,6 +254,44 @@ public interface Tensor { default Tensor sum(String dimension) { return sum(Collections.singletonList(dimension)); } default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); } + // ----------------- non-math query methods (that is, computations not returning a tensor) + + /** Returns the cell(s) of this tensor having the highest value */ + default List<Cell> largest() { + List<Cell> cells = new ArrayList<>(1); + double maxValue = Double.MIN_VALUE; + for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) { + Cell cell = i.next(); + if (cell.getValue() > maxValue) { + cells.clear(); + cells.add(cell.detach()); + maxValue = cell.getDoubleValue(); + } + else if (cell.getValue() == maxValue) { + cells.add(cell.detach()); + } + } + return cells; + } + + /** Returns the cell(s) of this tensor having the lowest value */ + default List<Cell> smallest() { + List<Cell> cells = new ArrayList<>(1); + double minValue = Double.MAX_VALUE; + for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) { + Cell cell = i.next(); + if (cell.getValue() < minValue) { + cells.clear(); + cells.add(cell.detach()); + minValue = cell.getDoubleValue(); + } + else if (cell.getValue() == minValue) { + cells.add(cell.detach()); + } + } + return cells; + } + // ----------------- serialization /** @@ -422,6 +464,13 @@ public interface Tensor { return getKey().hashCode() ^ getValue().hashCode(); // by Map.Entry spec } + public String toString(TensorType type) { return address.toString(type) + ":" + value; } + + /** + * Return a copy of this tensor cell which is valid beyond the lifetime of any iterator state which supplied it. + */ + public Cell detach() { return this; } + } interface Builder { |