summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java51
2 files changed, 55 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 02f54b5790a..1da013de012 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -739,6 +739,11 @@ public abstract class IndexedTensor implements Tensor {
@Override
public Double getValue() { return value; }
+ @Override
+ public Cell detach() {
+ return new Cell(getKey(), value);
+ }
+
}
// TODO: Make dimensionSizes a class
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 806d27ce70e..8b0aaa64551 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -63,7 +63,11 @@ public interface Tensor {
/** Returns the value of a cell, or NaN if this cell does not exist/have no value */
double get(TensorAddress address);
- /** Returns the cell of this in some undefined order */
+ /**
+ * Returns the cell of this in some undefined order.
+ * A cell instances is only valid until next() is called.
+ * Call detach() on the cell to obtain a long-lived instance.
+ */
Iterator<Cell> cellIterator();
/** Returns the values of this in some undefined order */
@@ -250,6 +254,44 @@ public interface Tensor {
default Tensor sum(String dimension) { return sum(Collections.singletonList(dimension)); }
default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); }
+ // ----------------- non-math query methods (that is, computations not returning a tensor)
+
+ /** Returns the cell(s) of this tensor having the highest value */
+ default List<Cell> largest() {
+ List<Cell> cells = new ArrayList<>(1);
+ double maxValue = Double.MIN_VALUE;
+ for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) {
+ Cell cell = i.next();
+ if (cell.getValue() > maxValue) {
+ cells.clear();
+ cells.add(cell.detach());
+ maxValue = cell.getDoubleValue();
+ }
+ else if (cell.getValue() == maxValue) {
+ cells.add(cell.detach());
+ }
+ }
+ return cells;
+ }
+
+ /** Returns the cell(s) of this tensor having the lowest value */
+ default List<Cell> smallest() {
+ List<Cell> cells = new ArrayList<>(1);
+ double minValue = Double.MAX_VALUE;
+ for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) {
+ Cell cell = i.next();
+ if (cell.getValue() < minValue) {
+ cells.clear();
+ cells.add(cell.detach());
+ minValue = cell.getDoubleValue();
+ }
+ else if (cell.getValue() == minValue) {
+ cells.add(cell.detach());
+ }
+ }
+ return cells;
+ }
+
// ----------------- serialization
/**
@@ -422,6 +464,13 @@ public interface Tensor {
return getKey().hashCode() ^ getValue().hashCode(); // by Map.Entry spec
}
+ public String toString(TensorType type) { return address.toString(type) + ":" + value; }
+
+ /**
+ * Return a copy of this tensor cell which is valid beyond the lifetime of any iterator state which supplied it.
+ */
+ public Cell detach() { return this; }
+
}
interface Builder {