diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 52 |
1 files changed, 51 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index bc351b45b28..8c1f4bda92c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -210,7 +210,36 @@ public abstract class IndexedTensor implements Tensor { } @Override - public String toString() { return Tensor.toStandardString(this); } + public String toString() { + if (type.rank() == 0) return Tensor.toStandardString(this); + if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty())) return Tensor.toStandardString(this); + + Indexes indexes = Indexes.of(dimensionSizes); + + StringBuilder b = new StringBuilder(type.toString()).append(":"); + for (int index = 0; index < size(); index++) { + indexes.next(); + + // start brackets + for (int i = 0; i < indexes.rightDimensionsWhichAreAtStart(); i++) + b.append("["); + + // value + if (type.valueType() == TensorType.Value.DOUBLE) + b.append(get(index)); + else if (type.valueType() == TensorType.Value.FLOAT) + b.append(getFloat(index)); + else + throw new IllegalStateException("Unexpected value type " + type.valueType()); + + // end bracket and comma + for (int i = 0; i < indexes.rightDimensionsWhichAreAtEnd(); i++) + b.append("]"); + if (index < size() - 1) + b.append(", "); + } + return b.toString(); + } @Override public boolean equals(Object other) { @@ -829,6 +858,27 @@ public abstract class IndexedTensor implements Tensor { public abstract void next(); + /** Returns the number of dimensions from the right which are currently at the start position (0) */ + int rightDimensionsWhichAreAtStart() { + int dimension = indexes.length - 1; + int atStartCount = 0; + while (dimension >= 0 && indexes[dimension] == 0) { + atStartCount++; + dimension--; + } + return atStartCount; + } + + /** Returns the number of dimensions from the right which are currently at the end position */ + int rightDimensionsWhichAreAtEnd() { + int dimension = indexes.length - 1; + int atEndCount = 0; + while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) { + atEndCount++; + dimension--; + } + return atEndCount; + } } private final static class EmptyIndexes extends Indexes { |