summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-25 16:40:00 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-25 16:40:00 +0200
commit7ef86b1fb25f2268d00fa3af87bc1e594de0b1b3 (patch)
tree5af4bc2b63e291b7e80d2ffc3ea85b5dfdf2b044 /vespajlib
parenta8949c869c613d671886b87ab684b2dfef9d9ca5 (diff)
Split values into IndexedDoubleTensor subclass
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java46
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java51
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java1
4 files changed, 68 insertions, 39 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 43388e4e18d..e4b6162eeca 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -792,10 +792,10 @@
"com.yahoo.tensor.Tensor"
],
"attributes": [
- "public"
+ "public",
+ "abstract"
],
"methods": [
- "public long size()",
"public java.util.Iterator cellIterator()",
"public com.yahoo.tensor.IndexedTensor$SubspaceIterator cellIterator(com.yahoo.tensor.PartialAddress, com.yahoo.tensor.DimensionSizes)",
"public java.util.Iterator valueIterator()",
@@ -803,14 +803,13 @@
"public java.util.Iterator subspaceIterator(java.util.Set)",
"public varargs double get(long[])",
"public double get(com.yahoo.tensor.TensorAddress)",
- "public double get(long)",
+ "public abstract double get(long)",
"public com.yahoo.tensor.TensorType type()",
- "public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
+ "public abstract com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public com.yahoo.tensor.Tensor remove(java.util.Set)",
- "public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
"public bridge synthetic com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)"
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
new file mode 100644
index 00000000000..27cecdab80c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
@@ -0,0 +1,46 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.Arrays;
+
+/**
+ * An indexed tensor implementation holding values as doubles
+ *
+ * @author bratseth
+ */
+class IndexedDoubleTensor extends IndexedTensor {
+
+ private final double[] values;
+
+ IndexedDoubleTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) {
+ super(type, dimensionSizes);
+ this.values = values;
+ }
+
+ @Override
+ public long size() {
+ return values.length;
+ }
+
+ /**
+ * Returns the value at the given index by direct lookup. Only use
+ * if you know the underlying data layout.
+ *
+ * @param valueIndex the direct index into the underlying data.
+ * @throws IndexOutOfBoundsException if index is out of bounds
+ */
+ @Override
+ public double get(long valueIndex) { return values[(int)valueIndex]; }
+
+ @Override
+ public IndexedTensor withType(TensorType type) {
+ if ( ! this.type().isRenamableTo(type))
+ throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
+ ": Types are not compatible");
+ return new IndexedDoubleTensor(type, dimensionSizes(), values);
+ }
+
+ @Override
+ public int hashCode() { return Arrays.hashCode(values); }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 38d832d01c2..5f2c04bbd56 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -20,7 +20,7 @@ import java.util.function.DoubleBinaryOperator;
*
* @author bratseth
*/
-public class IndexedTensor implements Tensor {
+public abstract class IndexedTensor implements Tensor {
/** The prescribed and possibly abstract type this is an instance of */
private final TensorType type;
@@ -28,17 +28,9 @@ public class IndexedTensor implements Tensor {
/** The sizes of the dimensions of this in the order of the dimensions of the type */
private final DimensionSizes dimensionSizes;
- private final double[] values;
-
- private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) {
+ IndexedTensor(TensorType type, DimensionSizes dimensionSizes) {
this.type = type;
this.dimensionSizes = dimensionSizes;
- this.values = values;
- }
-
- @Override
- public long size() {
- return values.length;
}
/**
@@ -96,13 +88,13 @@ public class IndexedTensor implements Tensor {
}
/**
- * Returns the value at the given indexes
+ * Returns the value at the given indexes as a double
*
* @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
public double get(long ... indexes) {
- return values[(int)toValueIndex(indexes, dimensionSizes)];
+ return get((int)toValueIndex(indexes, dimensionSizes));
}
/** Returns the value at this address, or NaN if there is no value at this address */
@@ -110,7 +102,7 @@ public class IndexedTensor implements Tensor {
public double get(TensorAddress address) {
// optimize for fast lookup within bounds:
try {
- return values[(int)toValueIndex(address, dimensionSizes)];
+ return get((int)toValueIndex(address, dimensionSizes));
}
catch (IndexOutOfBoundsException e) {
return Double.NaN;
@@ -124,7 +116,7 @@ public class IndexedTensor implements Tensor {
* @param valueIndex the direct index into the underlying data.
* @throws IndexOutOfBoundsException if index is out of bounds
*/
- public double get(long valueIndex) { return values[(int)valueIndex]; }
+ public abstract double get(long valueIndex);
private static long toValueIndex(long[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
@@ -164,13 +156,7 @@ public class IndexedTensor implements Tensor {
public TensorType type() { return type; }
@Override
- public IndexedTensor withType(TensorType type) {
- if (!this.type.isRenamableTo(type)) {
- throw new IllegalArgumentException("IndexedTensor.withType: types are not compatible. Current type: '" +
- this.type.toString() + "', requested type: '" + type.toString() + "'");
- }
- return new IndexedTensor(type, dimensionSizes, values);
- }
+ public abstract IndexedTensor withType(TensorType type);
public DimensionSizes dimensionSizes() {
return dimensionSizes;
@@ -179,13 +165,13 @@ public class IndexedTensor implements Tensor {
@Override
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
- return Collections.singletonMap(TensorAddress.of(), values[0]);
+ return Collections.singletonMap(TensorAddress.of(), get(0));
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
- Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
- for (long i = 0; i < values.length; i++) {
+ Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size());
+ for (long i = 0; i < size(); i++) {
indexes.next();
- builder.put(indexes.toAddress(), values[(int)i]);
+ builder.put(indexes.toAddress(), get(i));
}
return builder.build();
}
@@ -201,9 +187,6 @@ public class IndexedTensor implements Tensor {
}
@Override
- public int hashCode() { return Arrays.hashCode(values); }
-
- @Override
public String toString() { return Tensor.toStandardString(this); }
@Override
@@ -302,7 +285,7 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
- IndexedTensor tensor = new IndexedTensor(type, sizes, values);
+ IndexedTensor tensor = new IndexedDoubleTensor(type, sizes, values); // TODO
// prevent further modification
sizes = null;
values = null;
@@ -348,12 +331,12 @@ public class IndexedTensor implements Tensor {
if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values");
if (type.dimensions().isEmpty()) // single number
- return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
+ return new IndexedDoubleTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
double[] values = new double[(int)dimensionSizes.totalSize()];
fillValues(0, 0, firstDimension, dimensionSizes, values);
- return new IndexedTensor(type, dimensionSizes, values);
+ return new IndexedDoubleTensor(type, dimensionSizes, values);
}
private DimensionSizes findDimensionSizes(List<Object> firstDimension) {
@@ -460,7 +443,7 @@ public class IndexedTensor implements Tensor {
private final class CellIterator implements Iterator<Cell> {
private long count = 0;
- private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
+ private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size());
private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN);
@Override
@@ -485,13 +468,13 @@ public class IndexedTensor implements Tensor {
@Override
public boolean hasNext() {
- return count < values.length;
+ return count < size();
}
@Override
public Double next() {
try {
- return values[(int)count++];
+ return get(count++);
}
catch (IndexOutOfBoundsException e) {
throw new NoSuchElementException("No element at position " + count);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index df78f3dfc3a..b1c7a2341c0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -143,6 +143,7 @@ public class TensorType {
}
private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) {
+ if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed
if (generalization.dimensions().size() != this.dimensions().size()) return false;
for (int i = 0; i < generalization.dimensions().size(); i++) {
Dimension thisDimension = this.dimensions().get(i);