aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-10 11:39:39 -0800
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-10 11:39:39 -0800
commit4c46e1816d2cdfacd8435ad4d55e831929fc99ba (patch)
treed55a90aeeddcf9265a74e7f16129517e36f45375 /vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
parentb8d2859a9fece15dac2b9260d71dea39f8ce19b3 (diff)
Tensor parsing improvements
- Mixed tensor format parsing (outside expressions) - Validate structure of dense tensor strings
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java74
1 files changed, 57 insertions, 17 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 176ddfefc13..30923976fa5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor {
indexes.next();
// start brackets
- for (int i = 0; i < indexes.rightDimensionsWhichAreAtStart(); i++)
+ for (int i = 0; i < indexes.rightDimensionsAtStart(); i++)
b.append("[");
// value
@@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalStateException("Unexpected value type " + type.valueType());
// end bracket and comma
- for (int i = 0; i < indexes.rightDimensionsWhichAreAtEnd(); i++)
+ for (int i = 0; i < indexes.rightDimensionsAtEnd(); i++)
b.append("]");
if (index < size() - 1)
b.append(", ");
@@ -375,8 +375,22 @@ public abstract class IndexedTensor implements Tensor {
}
+ public interface DirectIndexBuilder {
+
+ TensorType type();
+
+
+
+ /** Sets a value by its <i>standard value order</i> index */
+ void cellByDirectIndex(long index, double value);
+
+ /** Sets a value by its <i>standard value order</i> index */
+ void cellByDirectIndex(long index, float value);
+
+ }
+
/** A bound builder can create the double array directly */
- public static abstract class BoundBuilder extends Builder {
+ public static abstract class BoundBuilder extends Builder implements DirectIndexBuilder {
private DimensionSizes sizes;
@@ -393,14 +407,16 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
}
- BoundBuilder fill(float [] values) {
+
+ BoundBuilder fill(float[] values) {
long index = 0;
for (float value : values) {
cellByDirectIndex(index++, value);
}
return this;
}
- BoundBuilder fill(double [] values) {
+
+ BoundBuilder fill(double[] values) {
long index = 0;
for (double value : values) {
cellByDirectIndex(index++, value);
@@ -410,12 +426,6 @@ public abstract class IndexedTensor implements Tensor {
DimensionSizes sizes() { return sizes; }
- /** Sets a value by its <i>standard value order</i> index */
- public abstract void cellByDirectIndex(long index, double value);
-
- /** Sets a value by its <i>standard value order</i> index */
- public abstract void cellByDirectIndex(long index, float value);
-
}
/**
@@ -869,8 +879,11 @@ public abstract class IndexedTensor implements Tensor {
public abstract void next();
+ /** Returns whether further values are available by calling next() */
+ public abstract boolean hasNext();
+
/** Returns the number of dimensions from the right which are currently at the start position (0) */
- int rightDimensionsWhichAreAtStart() {
+ int rightDimensionsAtStart() {
int dimension = indexes.length - 1;
int atStartCount = 0;
while (dimension >= 0 && indexes[dimension] == 0) {
@@ -881,7 +894,7 @@ public abstract class IndexedTensor implements Tensor {
}
/** Returns the number of dimensions from the right which are currently at the end position */
- int rightDimensionsWhichAreAtEnd() {
+ int rightDimensionsAtEnd() {
int dimension = indexes.length - 1;
int atEndCount = 0;
while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) {
@@ -904,10 +917,15 @@ public abstract class IndexedTensor implements Tensor {
@Override
public void next() {}
+ @Override
+ public boolean hasNext() { return false; }
+
}
private final static class SingleValueIndexes extends Indexes {
+ private boolean exhausted = false;
+
private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) {
super(sourceSizes, iterateSizes, indexes);
}
@@ -916,7 +934,10 @@ public abstract class IndexedTensor implements Tensor {
public long size() { return 1; }
@Override
- public void next() {}
+ public void next() { exhausted = true; }
+
+ @Override
+ public boolean hasNext() { return ! exhausted; }
}
@@ -945,7 +966,7 @@ public abstract class IndexedTensor implements Tensor {
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
*
- * @throws RuntimeException if this is called more times than its size
+ * @throws RuntimeException if this is called when hasNext returns false
*/
@Override
public void next() {
@@ -957,6 +978,15 @@ public abstract class IndexedTensor implements Tensor {
indexes[iterateDimensions.get(iterateDimensionsIndex)]++;
}
+ @Override
+ public boolean hasNext() {
+ for (int iterateDimension : iterateDimensions) {
+ if (indexes[iterateDimension] + 1 < dimensionSizes().size(iterateDimension))
+ return true; // some dimension is not at the end
+ }
+ return false;
+ }
+
}
/** In this case we can reuse the source index computation for the iteration index */
@@ -1016,7 +1046,7 @@ public abstract class IndexedTensor implements Tensor {
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
*
- * @throws RuntimeException if this is called more times than its size
+ * @throws RuntimeException if this is called when hasNext returns false
*/
@Override
public void next() {
@@ -1031,6 +1061,11 @@ public abstract class IndexedTensor implements Tensor {
@Override
long toIterationValueIndex() { return currentIterationValueIndex; }
+ @Override
+ public boolean hasNext() {
+ return indexes[iterateDimension] + 1 < size;
+ }
+
}
/** In this case we only need to keep track of one index */
@@ -1068,7 +1103,7 @@ public abstract class IndexedTensor implements Tensor {
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
*
- * @throws RuntimeException if this is called more times than its size
+ * @throws RuntimeException if this is called when hasNext returns false
*/
@Override
public void next() {
@@ -1077,6 +1112,11 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
+ public boolean hasNext() {
+ return indexes[iterateDimension] + 1 < size;
+ }
+
+ @Override
long toSourceValueIndex() { return currentValueIndex; }
@Override