summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java52
1 files changed, 51 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index bc351b45b28..8c1f4bda92c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -210,7 +210,36 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- public String toString() { return Tensor.toStandardString(this); }
+ public String toString() {
+ if (type.rank() == 0) return Tensor.toStandardString(this);
+ if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty())) return Tensor.toStandardString(this);
+
+ Indexes indexes = Indexes.of(dimensionSizes);
+
+ StringBuilder b = new StringBuilder(type.toString()).append(":");
+ for (int index = 0; index < size(); index++) {
+ indexes.next();
+
+ // start brackets
+ for (int i = 0; i < indexes.rightDimensionsWhichAreAtStart(); i++)
+ b.append("[");
+
+ // value
+ if (type.valueType() == TensorType.Value.DOUBLE)
+ b.append(get(index));
+ else if (type.valueType() == TensorType.Value.FLOAT)
+ b.append(getFloat(index));
+ else
+ throw new IllegalStateException("Unexpected value type " + type.valueType());
+
+ // end bracket and comma
+ for (int i = 0; i < indexes.rightDimensionsWhichAreAtEnd(); i++)
+ b.append("]");
+ if (index < size() - 1)
+ b.append(", ");
+ }
+ return b.toString();
+ }
@Override
public boolean equals(Object other) {
@@ -829,6 +858,27 @@ public abstract class IndexedTensor implements Tensor {
public abstract void next();
+ /** Returns the number of dimensions from the right which are currently at the start position (0) */
+ int rightDimensionsWhichAreAtStart() {
+ int dimension = indexes.length - 1;
+ int atStartCount = 0;
+ while (dimension >= 0 && indexes[dimension] == 0) {
+ atStartCount++;
+ dimension--;
+ }
+ return atStartCount;
+ }
+
+ /** Returns the number of dimensions from the right which are currently at the end position */
+ int rightDimensionsWhichAreAtEnd() {
+ int dimension = indexes.length - 1;
+ int atEndCount = 0;
+ while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) {
+ atEndCount++;
+ dimension--;
+ }
+ return atEndCount;
+ }
}
private final static class EmptyIndexes extends Indexes {