diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-10 11:39:39 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-10 11:39:39 -0800 |
commit | 4c46e1816d2cdfacd8435ad4d55e831929fc99ba (patch) | |
tree | d55a90aeeddcf9265a74e7f16129517e36f45375 /vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | |
parent | b8d2859a9fece15dac2b9260d71dea39f8ce19b3 (diff) |
Tensor parsing improvements
- Mixed tensor format parsing (outside expressions)
- Validate structure of dense tensor strings
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 74 |
1 files changed, 57 insertions, 17 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 176ddfefc13..30923976fa5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor { indexes.next(); // start brackets - for (int i = 0; i < indexes.rightDimensionsWhichAreAtStart(); i++) + for (int i = 0; i < indexes.rightDimensionsAtStart(); i++) b.append("["); // value @@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalStateException("Unexpected value type " + type.valueType()); // end bracket and comma - for (int i = 0; i < indexes.rightDimensionsWhichAreAtEnd(); i++) + for (int i = 0; i < indexes.rightDimensionsAtEnd(); i++) b.append("]"); if (index < size() - 1) b.append(", "); @@ -375,8 +375,22 @@ public abstract class IndexedTensor implements Tensor { } + public interface DirectIndexBuilder { + + TensorType type(); + + + + /** Sets a value by its <i>standard value order</i> index */ + void cellByDirectIndex(long index, double value); + + /** Sets a value by its <i>standard value order</i> index */ + void cellByDirectIndex(long index, float value); + + } + /** A bound builder can create the double array directly */ - public static abstract class BoundBuilder extends Builder { + public static abstract class BoundBuilder extends Builder implements DirectIndexBuilder { private DimensionSizes sizes; @@ -393,14 +407,16 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; } - BoundBuilder fill(float [] values) { + + BoundBuilder fill(float[] values) { long index = 0; for (float value : values) { cellByDirectIndex(index++, value); } return this; } - BoundBuilder fill(double [] values) { + + BoundBuilder fill(double[] values) { long index = 0; for (double value : values) { cellByDirectIndex(index++, value); @@ -410,12 +426,6 @@ public abstract class IndexedTensor implements Tensor { DimensionSizes sizes() { return sizes; } - /** Sets a value by its <i>standard value order</i> index */ - public abstract void cellByDirectIndex(long index, double value); - - /** Sets a value by its <i>standard value order</i> index */ - public abstract void cellByDirectIndex(long index, float value); - } /** @@ -869,8 +879,11 @@ public abstract class IndexedTensor implements Tensor { public abstract void next(); + /** Returns whether further values are available by calling next() */ + public abstract boolean hasNext(); + /** Returns the number of dimensions from the right which are currently at the start position (0) */ - int rightDimensionsWhichAreAtStart() { + int rightDimensionsAtStart() { int dimension = indexes.length - 1; int atStartCount = 0; while (dimension >= 0 && indexes[dimension] == 0) { @@ -881,7 +894,7 @@ public abstract class IndexedTensor implements Tensor { } /** Returns the number of dimensions from the right which are currently at the end position */ - int rightDimensionsWhichAreAtEnd() { + int rightDimensionsAtEnd() { int dimension = indexes.length - 1; int atEndCount = 0; while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) { @@ -904,10 +917,15 @@ public abstract class IndexedTensor implements Tensor { @Override public void next() {} + @Override + public boolean hasNext() { return false; } + } private final static class SingleValueIndexes extends Indexes { + private boolean exhausted = false; + private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) { super(sourceSizes, iterateSizes, indexes); } @@ -916,7 +934,10 @@ public abstract class IndexedTensor implements Tensor { public long size() { return 1; } @Override - public void next() {} + public void next() { exhausted = true; } + + @Override + public boolean hasNext() { return ! exhausted; } } @@ -945,7 +966,7 @@ public abstract class IndexedTensor implements Tensor { * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. * - * @throws RuntimeException if this is called more times than its size + * @throws RuntimeException if this is called when hasNext returns false */ @Override public void next() { @@ -957,6 +978,15 @@ public abstract class IndexedTensor implements Tensor { indexes[iterateDimensions.get(iterateDimensionsIndex)]++; } + @Override + public boolean hasNext() { + for (int iterateDimension : iterateDimensions) { + if (indexes[iterateDimension] + 1 < dimensionSizes().size(iterateDimension)) + return true; // some dimension is not at the end + } + return false; + } + } /** In this case we can reuse the source index computation for the iteration index */ @@ -1016,7 +1046,7 @@ public abstract class IndexedTensor implements Tensor { * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. * - * @throws RuntimeException if this is called more times than its size + * @throws RuntimeException if this is called when hasNext returns false */ @Override public void next() { @@ -1031,6 +1061,11 @@ public abstract class IndexedTensor implements Tensor { @Override long toIterationValueIndex() { return currentIterationValueIndex; } + @Override + public boolean hasNext() { + return indexes[iterateDimension] + 1 < size; + } + } /** In this case we only need to keep track of one index */ @@ -1068,7 +1103,7 @@ public abstract class IndexedTensor implements Tensor { * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. * - * @throws RuntimeException if this is called more times than its size + * @throws RuntimeException if this is called when hasNext returns false */ @Override public void next() { @@ -1077,6 +1112,11 @@ public abstract class IndexedTensor implements Tensor { } @Override + public boolean hasNext() { + return indexes[iterateDimension] + 1 < size; + } + + @Override long toSourceValueIndex() { return currentValueIndex; } @Override |