summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2019-04-29 10:27:15 +0200
committerGitHub <noreply@github.com>2019-04-29 10:27:15 +0200
commit086cb954dfd88caaa4e7998e4e3d8cd052db7907 (patch)
treeadd15d986496fedc9a2f740cb04bc0506eb0a6d0
parent10310b296e969f09683eb839e519237873fa1c0d (diff)
parent97fcbb11ff8e3ab2173fd09d0af071d06e6629e8 (diff)
Merge pull request #9201 from vespa-engine/bratseth/float-tensor
Bratseth/float tensor
-rw-r--r--linguistics/src/main/java/com/yahoo/language/process/Transformer.java2
-rw-r--r--vespajlib/abi-spec.json82
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java112
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java112
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java161
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java15
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java34
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java11
11 files changed, 445 insertions, 116 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Transformer.java b/linguistics/src/main/java/com/yahoo/language/process/Transformer.java
index 398ddc0262b..46f3c060d4e 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Transformer.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Transformer.java
@@ -6,7 +6,7 @@ import com.yahoo.language.Language;
/**
* Interface for providers of text transformations such as accent removal.
*
- * @author <a href="mailto:mathiasm@yahoo-inc.com">Mathias Mølster Lidal</a>
+ * @author Mathias Mølster Lidal
*/
public interface Transformer {
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 43388e4e18d..4f81f3baea8 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -706,27 +706,77 @@
],
"fields": []
},
- "com.yahoo.tensor.IndexedTensor$BoundBuilder": {
- "superClass": "com.yahoo.tensor.IndexedTensor$Builder",
+ "com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder": {
+ "superClass": "com.yahoo.tensor.IndexedTensor$BoundBuilder",
"interfaces": [],
"attributes": [
"public"
],
"methods": [
+ "public varargs com.yahoo.tensor.IndexedTensor$BoundBuilder cell(float, long[])",
"public varargs com.yahoo.tensor.IndexedTensor$BoundBuilder cell(double, long[])",
"public com.yahoo.tensor.Tensor$Builder$CellBuilder cell()",
+ "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
"public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
"public com.yahoo.tensor.IndexedTensor build()",
+ "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)",
"public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)",
+ "public void cellByDirectIndex(long, float)",
"public void cellByDirectIndex(long, double)",
+ "public bridge synthetic com.yahoo.tensor.IndexedTensor$Builder cell(float, long[])",
"public bridge synthetic com.yahoo.tensor.IndexedTensor$Builder cell(double, long[])",
"public bridge synthetic com.yahoo.tensor.Tensor build()",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)",
"public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(float, long[])",
"public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(double, long[])",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
"public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)"
],
"fields": []
},
+ "com.yahoo.tensor.IndexedFloatTensor$BoundFloatBuilder": {
+ "superClass": "com.yahoo.tensor.IndexedTensor$BoundBuilder",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public varargs com.yahoo.tensor.IndexedTensor$BoundBuilder cell(double, long[])",
+ "public varargs com.yahoo.tensor.IndexedTensor$BoundBuilder cell(float, long[])",
+ "public com.yahoo.tensor.Tensor$Builder$CellBuilder cell()",
+ "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
+ "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
+ "public com.yahoo.tensor.IndexedTensor build()",
+ "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)",
+ "public com.yahoo.tensor.IndexedTensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)",
+ "public void cellByDirectIndex(long, double)",
+ "public void cellByDirectIndex(long, float)",
+ "public bridge synthetic com.yahoo.tensor.IndexedTensor$Builder cell(float, long[])",
+ "public bridge synthetic com.yahoo.tensor.IndexedTensor$Builder cell(double, long[])",
+ "public bridge synthetic com.yahoo.tensor.Tensor build()",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(float, long[])",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(double, long[])",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)"
+ ],
+ "fields": []
+ },
+ "com.yahoo.tensor.IndexedTensor$BoundBuilder": {
+ "superClass": "com.yahoo.tensor.IndexedTensor$Builder",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "abstract"
+ ],
+ "methods": [
+ "public abstract void cellByDirectIndex(long, double)",
+ "public abstract void cellByDirectIndex(long, float)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.IndexedTensor$Builder": {
"superClass": "java.lang.Object",
"interfaces": [
@@ -740,9 +790,11 @@
"public static com.yahoo.tensor.IndexedTensor$Builder of(com.yahoo.tensor.TensorType)",
"public static com.yahoo.tensor.IndexedTensor$Builder of(com.yahoo.tensor.TensorType, com.yahoo.tensor.DimensionSizes)",
"public varargs abstract com.yahoo.tensor.IndexedTensor$Builder cell(double, long[])",
+ "public varargs abstract com.yahoo.tensor.IndexedTensor$Builder cell(float, long[])",
"public com.yahoo.tensor.TensorType type()",
"public abstract com.yahoo.tensor.IndexedTensor build()",
"public bridge synthetic com.yahoo.tensor.Tensor build()",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(float, long[])",
"public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(double, long[])"
],
"fields": []
@@ -792,25 +844,26 @@
"com.yahoo.tensor.Tensor"
],
"attributes": [
- "public"
+ "public",
+ "abstract"
],
"methods": [
- "public long size()",
"public java.util.Iterator cellIterator()",
"public com.yahoo.tensor.IndexedTensor$SubspaceIterator cellIterator(com.yahoo.tensor.PartialAddress, com.yahoo.tensor.DimensionSizes)",
"public java.util.Iterator valueIterator()",
"public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)",
"public java.util.Iterator subspaceIterator(java.util.Set)",
"public varargs double get(long[])",
+ "public varargs float getFloat(long[])",
"public double get(com.yahoo.tensor.TensorAddress)",
- "public double get(long)",
+ "public abstract double get(long)",
+ "public abstract float getFloat(long)",
"public com.yahoo.tensor.TensorType type()",
- "public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
+ "public abstract com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public com.yahoo.tensor.Tensor remove(java.util.Set)",
- "public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
"public bridge synthetic com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)"
@@ -829,11 +882,15 @@
"public static com.yahoo.tensor.MappedTensor$Builder of(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor$Builder$CellBuilder cell()",
"public com.yahoo.tensor.TensorType type()",
+ "public com.yahoo.tensor.MappedTensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
"public com.yahoo.tensor.MappedTensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
+ "public varargs com.yahoo.tensor.MappedTensor$Builder cell(float, long[])",
"public varargs com.yahoo.tensor.MappedTensor$Builder cell(double, long[])",
"public com.yahoo.tensor.MappedTensor build()",
"public bridge synthetic com.yahoo.tensor.Tensor build()",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(float, long[])",
"public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(double, long[])",
+ "public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
"public bridge synthetic com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)"
],
"fields": []
@@ -870,6 +927,7 @@
],
"methods": [
"public long denseSubspaceSize()",
+ "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
"public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])",
"public com.yahoo.tensor.MixedTensor build()",
@@ -889,6 +947,7 @@
"methods": [
"public static com.yahoo.tensor.MixedTensor$Builder of(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.TensorType type()",
+ "public varargs com.yahoo.tensor.Tensor$Builder cell(float, long[])",
"public varargs com.yahoo.tensor.Tensor$Builder cell(double, long[])",
"public com.yahoo.tensor.Tensor$Builder$CellBuilder cell()",
"public abstract com.yahoo.tensor.MixedTensor build()",
@@ -917,6 +976,7 @@
"public"
],
"methods": [
+ "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
"public com.yahoo.tensor.MixedTensor build()",
"public void trackBounds(com.yahoo.tensor.TensorAddress)",
@@ -982,7 +1042,8 @@
"methods": [
"public com.yahoo.tensor.Tensor$Builder$CellBuilder label(java.lang.String, java.lang.String)",
"public com.yahoo.tensor.Tensor$Builder$CellBuilder label(java.lang.String, long)",
- "public com.yahoo.tensor.Tensor$Builder value(double)"
+ "public com.yahoo.tensor.Tensor$Builder value(double)",
+ "public com.yahoo.tensor.Tensor$Builder value(float)"
],
"fields": []
},
@@ -1000,8 +1061,11 @@
"public abstract com.yahoo.tensor.TensorType type()",
"public abstract com.yahoo.tensor.Tensor$Builder$CellBuilder cell()",
"public abstract com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
+ "public abstract com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
"public varargs abstract com.yahoo.tensor.Tensor$Builder cell(double, long[])",
+ "public varargs abstract com.yahoo.tensor.Tensor$Builder cell(float, long[])",
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)",
+ "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)",
"public abstract com.yahoo.tensor.Tensor build()"
],
"fields": []
@@ -1017,6 +1081,8 @@
"methods": [
"public com.yahoo.tensor.TensorAddress getKey()",
"public java.lang.Double getValue()",
+ "public float getFloatValue()",
+ "public double getDoubleValue()",
"public java.lang.Double setValue(java.lang.Double)",
"public boolean equals(java.lang.Object)",
"public int hashCode()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
new file mode 100644
index 00000000000..285837a1bc6
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
@@ -0,0 +1,112 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.Arrays;
+
+/**
+ * An indexed tensor implementation holding values as doubles
+ *
+ * @author bratseth
+ */
+class IndexedDoubleTensor extends IndexedTensor {
+
+ private final double[] values;
+
+ IndexedDoubleTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) {
+ super(type, dimensionSizes);
+ this.values = values;
+ }
+
+ @Override
+ public long size() {
+ return values.length;
+ }
+
+ @Override
+ public double get(long valueIndex) { return values[(int)valueIndex]; }
+
+ @Override
+ public float getFloat(long valueIndex) { return (float)get(valueIndex); }
+
+ @Override
+ public IndexedTensor withType(TensorType type) {
+ throwOnIncompatibleType(type);
+ return new IndexedDoubleTensor(type, dimensionSizes(), values);
+ }
+
+ @Override
+ public int hashCode() { return Arrays.hashCode(values); }
+
+ /** A bound builder can create the double array directly */
+ public static class BoundDoubleBuilder extends BoundBuilder {
+
+ private double[] values;
+
+ BoundDoubleBuilder(TensorType type, DimensionSizes sizes) {
+ super(type, sizes);
+ values = new double[(int)sizes.totalSize()];
+ }
+
+ @Override
+ public IndexedTensor.BoundBuilder cell(float value, long ... indexes) {
+ return cell((double)value, indexes);
+ }
+
+ @Override
+ public IndexedTensor.BoundBuilder cell(double value, long ... indexes) {
+ values[(int)toValueIndex(indexes, sizes())] = value;
+ return this;
+ }
+
+ @Override
+ public CellBuilder cell() {
+ return new CellBuilder(type, this);
+ }
+
+ @Override
+ public Builder cell(TensorAddress address, float value) {
+ return cell(address, (double)value);
+ }
+
+ @Override
+ public Builder cell(TensorAddress address, double value) {
+ values[(int)toValueIndex(address, sizes())] = value;
+ return this;
+ }
+
+ @Override
+ public IndexedTensor build() {
+ IndexedTensor tensor = new IndexedDoubleTensor(type, sizes(), values);
+ // prevent further modification
+ values = null;
+ return tensor;
+ }
+
+ @Override
+ public Builder cell(Cell cell, float value) {
+ return cell(cell, (double)value);
+ }
+
+ @Override
+ public Builder cell(Cell cell, double value) {
+ long directIndex = cell.getDirectIndex();
+ if (directIndex >= 0) // optimization
+ values[(int)directIndex] = value;
+ else
+ super.cell(cell, value);
+ return this;
+ }
+
+ @Override
+ public void cellByDirectIndex(long index, float value) {
+ cellByDirectIndex(index, (double)value);
+ }
+
+ @Override
+ public void cellByDirectIndex(long index, double value) {
+ values[(int)index] = value;
+ }
+
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
new file mode 100644
index 00000000000..8f8c24c8421
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
@@ -0,0 +1,112 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.Arrays;
+
+/**
+ * An indexed tensor implementation holding values as floats
+ *
+ * @author bratseth
+ */
+class IndexedFloatTensor extends IndexedTensor {
+
+ private final float[] values;
+
+ IndexedFloatTensor(TensorType type, DimensionSizes dimensionSizes, float[] values) {
+ super(type, dimensionSizes);
+ this.values = values;
+ }
+
+ @Override
+ public long size() {
+ return values.length;
+ }
+
+ @Override
+ public double get(long valueIndex) { return getFloat(valueIndex); }
+
+ @Override
+ public float getFloat(long valueIndex) { return values[(int)valueIndex]; }
+
+ @Override
+ public IndexedTensor withType(TensorType type) {
+ throwOnIncompatibleType(type);
+ return new IndexedFloatTensor(type, dimensionSizes(), values);
+ }
+
+ @Override
+ public int hashCode() { return Arrays.hashCode(values); }
+
+ /** A bound builder can create the float array directly */
+ public static class BoundFloatBuilder extends BoundBuilder {
+
+ private float[] values;
+
+ BoundFloatBuilder(TensorType type, DimensionSizes sizes) {
+ super(type, sizes);
+ values = new float[(int)sizes.totalSize()];
+ }
+
+ @Override
+ public IndexedTensor.BoundBuilder cell(double value, long ... indexes) {
+ return cell((float)value, indexes);
+ }
+
+ @Override
+ public IndexedTensor.BoundBuilder cell(float value, long ... indexes) {
+ values[(int)toValueIndex(indexes, sizes())] = value;
+ return this;
+ }
+
+ @Override
+ public CellBuilder cell() {
+ return new CellBuilder(type, this);
+ }
+
+ @Override
+ public Builder cell(TensorAddress address, double value) {
+ return cell(address, (float)value);
+ }
+
+ @Override
+ public Builder cell(TensorAddress address, float value) {
+ values[(int)toValueIndex(address, sizes())] = value;
+ return this;
+ }
+
+ @Override
+ public IndexedTensor build() {
+ IndexedTensor tensor = new IndexedFloatTensor(type, sizes(), values);
+ // prevent further modification
+ values = null;
+ return tensor;
+ }
+
+ @Override
+ public Builder cell(Cell cell, double value) {
+ return cell(cell, (float)value);
+ }
+
+ @Override
+ public Builder cell(Cell cell, float value) {
+ long directIndex = cell.getDirectIndex();
+ if (directIndex >= 0) // optimization
+ values[(int)directIndex] = value;
+ else
+ super.cell(cell, value);
+ return this;
+ }
+
+ @Override
+ public void cellByDirectIndex(long index, double value) {
+ cellByDirectIndex(index, (float)value);
+ }
+
+ @Override
+ public void cellByDirectIndex(long index, float value) {
+ values[(int)index] = value;
+ }
+
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 38d832d01c2..19edfc0269e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -20,7 +20,7 @@ import java.util.function.DoubleBinaryOperator;
*
* @author bratseth
*/
-public class IndexedTensor implements Tensor {
+public abstract class IndexedTensor implements Tensor {
/** The prescribed and possibly abstract type this is an instance of */
private final TensorType type;
@@ -28,17 +28,9 @@ public class IndexedTensor implements Tensor {
/** The sizes of the dimensions of this in the order of the dimensions of the type */
private final DimensionSizes dimensionSizes;
- private final double[] values;
-
- private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) {
+ IndexedTensor(TensorType type, DimensionSizes dimensionSizes) {
this.type = type;
this.dimensionSizes = dimensionSizes;
- this.values = values;
- }
-
- @Override
- public long size() {
- return values.length;
}
/**
@@ -96,13 +88,23 @@ public class IndexedTensor implements Tensor {
}
/**
- * Returns the value at the given indexes
+ * Returns the value at the given indexes as a double
*
* @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
public double get(long ... indexes) {
- return values[(int)toValueIndex(indexes, dimensionSizes)];
+ return get((int)toValueIndex(indexes, dimensionSizes));
+ }
+
+ /**
+ * Returns the value at the given indexes as a float
+ *
+ * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
+ * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
+ */
+ public float getFloat(long ... indexes) {
+ return getFloat((int)toValueIndex(indexes, dimensionSizes));
}
/** Returns the value at this address, or NaN if there is no value at this address */
@@ -110,7 +112,7 @@ public class IndexedTensor implements Tensor {
public double get(TensorAddress address) {
// optimize for fast lookup within bounds:
try {
- return values[(int)toValueIndex(address, dimensionSizes)];
+ return get((int)toValueIndex(address, dimensionSizes));
}
catch (IndexOutOfBoundsException e) {
return Double.NaN;
@@ -118,15 +120,24 @@ public class IndexedTensor implements Tensor {
}
/**
- * Returns the value at the given index by direct lookup. Only use
+ * Returns the value at the given index as a double by direct lookup. Only use
* if you know the underlying data layout.
*
* @param valueIndex the direct index into the underlying data.
* @throws IndexOutOfBoundsException if index is out of bounds
*/
- public double get(long valueIndex) { return values[(int)valueIndex]; }
+ public abstract double get(long valueIndex);
- private static long toValueIndex(long[] indexes, DimensionSizes sizes) {
+ /**
+ * Returns the value at the given index as a float by direct lookup. Only use
+ * if you know the underlying data layout.
+ *
+ * @param valueIndex the direct index into the underlying data.
+ * @throws IndexOutOfBoundsException if index is out of bounds
+ */
+ public abstract float getFloat(long valueIndex);
+
+ static long toValueIndex(long[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
if (indexes.length == 0) return 0; // for speed
@@ -140,7 +151,7 @@ public class IndexedTensor implements Tensor {
return valueIndex;
}
- private static long toValueIndex(TensorAddress address, DimensionSizes sizes) {
+ static long toValueIndex(TensorAddress address, DimensionSizes sizes) {
if (address.isEmpty()) return 0;
long valueIndex = 0;
@@ -160,17 +171,17 @@ public class IndexedTensor implements Tensor {
return product;
}
+ void throwOnIncompatibleType(TensorType type) {
+ if ( ! this.type().isRenamableTo(type))
+ throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
+ ": Types are not compatible");
+ }
+
@Override
public TensorType type() { return type; }
@Override
- public IndexedTensor withType(TensorType type) {
- if (!this.type.isRenamableTo(type)) {
- throw new IllegalArgumentException("IndexedTensor.withType: types are not compatible. Current type: '" +
- this.type.toString() + "', requested type: '" + type.toString() + "'");
- }
- return new IndexedTensor(type, dimensionSizes, values);
- }
+ public abstract IndexedTensor withType(TensorType type);
public DimensionSizes dimensionSizes() {
return dimensionSizes;
@@ -179,13 +190,13 @@ public class IndexedTensor implements Tensor {
@Override
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
- return Collections.singletonMap(TensorAddress.of(), values[0]);
+ return Collections.singletonMap(TensorAddress.of(), get(0));
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
- Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
- for (long i = 0; i < values.length; i++) {
+ Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size());
+ for (long i = 0; i < size(); i++) {
indexes.next();
- builder.put(indexes.toAddress(), values[(int)i]);
+ builder.put(indexes.toAddress(), get(i));
}
return builder.build();
}
@@ -201,9 +212,6 @@ public class IndexedTensor implements Tensor {
}
@Override
- public int hashCode() { return Arrays.hashCode(values); }
-
- @Override
public String toString() { return Tensor.toStandardString(this); }
@Override
@@ -222,7 +230,7 @@ public class IndexedTensor implements Tensor {
public static Builder of(TensorType type) {
if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
- return new BoundBuilder(type);
+ return of(type, BoundBuilder.dimensionSizesOf(type));
else
return new UnboundBuilder(type);
}
@@ -235,8 +243,8 @@ public class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes) {
// validate
if (sizes.dimensions() != type.dimensions().size())
- throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
- "for " + type);
+ throw new IllegalArgumentException(sizes.dimensions() +
+ " is the wrong number of dimensions for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
Optional<Long> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
@@ -245,10 +253,16 @@ public class IndexedTensor implements Tensor {
" but cannot be larger than " + size.get() + " in " + type);
}
- return new BoundBuilder(type, sizes);
+ if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ else if (type.valueType() == TensorType.Value.DOUBLE)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ else
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
}
public abstract Builder cell(double value, long ... indexes);
+ public abstract Builder cell(float value, long ... indexes);
@Override
public TensorType type() { return type; }
@@ -259,74 +273,29 @@ public class IndexedTensor implements Tensor {
}
/** A bound builder can create the double array directly */
- public static class BoundBuilder extends Builder {
+ public static abstract class BoundBuilder extends Builder {
private DimensionSizes sizes;
- private double[] values;
- private BoundBuilder(TensorType type) {
- this(type, dimensionSizesOf(type));
- }
-
- static DimensionSizes dimensionSizesOf(TensorType type) {
+ private static DimensionSizes dimensionSizesOf(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < type.dimensions().size(); i++)
b.set(i, type.dimensions().get(i).size().get());
return b.build();
}
- private BoundBuilder(TensorType type, DimensionSizes sizes) {
+ BoundBuilder(TensorType type, DimensionSizes sizes) {
super(type);
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = new double[(int)sizes.totalSize()];
}
- @Override
- public BoundBuilder cell(double value, long ... indexes) {
- values[(int)toValueIndex(indexes, sizes)] = value;
- return this;
- }
-
- @Override
- public CellBuilder cell() {
- return new CellBuilder(type, this);
- }
-
- @Override
- public Builder cell(TensorAddress address, double value) {
- values[(int)toValueIndex(address, sizes)] = value;
- return this;
- }
-
- @Override
- public IndexedTensor build() {
- IndexedTensor tensor = new IndexedTensor(type, sizes, values);
- // prevent further modification
- sizes = null;
- values = null;
- return tensor;
- }
+ DimensionSizes sizes() { return sizes; }
- @Override
- public Builder cell(Cell cell, double value) {
- long directIndex = cell.getDirectIndex();
- if (directIndex >= 0) // optimization
- values[(int)directIndex] = value;
- else
- super.cell(cell, value);
- return this;
- }
+ public abstract void cellByDirectIndex(long index, double value);
- /**
- * Set a cell value by the index in the internal layout of this cell.
- * This requires knowledge of the internal layout of cells in this implementation, and should therefore
- * probably not be used (but when it can be used it is fast).
- */
- public void cellByDirectIndex(long index, double value) {
- values[(int)index] = value;
- }
+ public abstract void cellByDirectIndex(long index, float value);
}
@@ -348,12 +317,12 @@ public class IndexedTensor implements Tensor {
if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values");
if (type.dimensions().isEmpty()) // single number
- return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
+ return new IndexedDoubleTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
double[] values = new double[(int)dimensionSizes.totalSize()];
fillValues(0, 0, firstDimension, dimensionSizes, values);
- return new IndexedTensor(type, dimensionSizes, values);
+ return new IndexedDoubleTensor(type, dimensionSizes, values);
}
private DimensionSizes findDimensionSizes(List<Object> firstDimension) {
@@ -406,6 +375,11 @@ public class IndexedTensor implements Tensor {
}
@Override
+ public Builder cell(TensorAddress address, float value) {
+ return cell(address, (double)value);
+ }
+
+ @Override
public Builder cell(TensorAddress address, double value) {
long[] indexes = new long[address.size()];
for (int i = 0; i < address.size(); i++) {
@@ -415,6 +389,11 @@ public class IndexedTensor implements Tensor {
return this;
}
+ @Override
+ public Builder cell(float value, long... indexes) {
+ return cell((double)value, indexes);
+ }
+
/**
* Set a value using an index API. The number of indexes must be the same as the dimensions in the type of this.
* Values can be written in any order but all values needed to make this dense must be provided
@@ -460,7 +439,7 @@ public class IndexedTensor implements Tensor {
private final class CellIterator implements Iterator<Cell> {
private long count = 0;
- private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
+ private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size());
private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN);
@Override
@@ -485,13 +464,13 @@ public class IndexedTensor implements Tensor {
@Override
public boolean hasNext() {
- return count < values.length;
+ return count < size();
}
@Override
public Double next() {
try {
- return values[(int)count++];
+ return get(count++);
}
catch (IndexOutOfBoundsException e) {
throw new NoSuchElementException("No element at position " + count);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 22ceed22d3e..693c4b5f2b0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -115,12 +115,22 @@ public class MappedTensor implements Tensor {
public TensorType type() { return type; }
@Override
+ public Builder cell(TensorAddress address, float value) {
+ return cell(address, (double)value);
+ }
+
+ @Override
public Builder cell(TensorAddress address, double value) {
cells.put(address, value);
return this;
}
@Override
+ public Builder cell(float value, long... labels) {
+ return cell((double)value, labels);
+ }
+
+ @Override
public Builder cell(double value, long... labels) {
cells.put(TensorAddress.of(labels), value);
return this;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index c06cb2a0986..95f64cec0c1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -193,6 +193,11 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor.Builder cell(float value, long... labels) {
+ return cell((double)value, labels);
+ }
+
+ @Override
public Tensor.Builder cell(double value, long... labels) {
throw new UnsupportedOperationException("Not implemented.");
}
@@ -236,6 +241,11 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor.Builder cell(TensorAddress address, float value) {
+ return cell(address, (double)value);
+ }
+
+ @Override
public Tensor.Builder cell(TensorAddress address, double value) {
TensorAddress sparsePart = index.sparsePartialAddress(address);
long denseOffset = index.denseOffset(address);
@@ -293,6 +303,11 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor.Builder cell(TensorAddress address, float value) {
+ return cell(address, (double)value);
+ }
+
+ @Override
public Tensor.Builder cell(TensorAddress address, double value) {
cells.put(address, value);
trackBounds(address);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index eb16801c306..ebb341147cf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -370,9 +370,9 @@ public interface Tensor {
class Cell implements Map.Entry<TensorAddress, Double> {
private final TensorAddress address;
- private final Double value;
+ private final Number value;
- Cell(TensorAddress address, Double value) {
+ Cell(TensorAddress address, Number value) {
this.address = address;
this.value = value;
}
@@ -387,8 +387,15 @@ public interface Tensor {
*/
long getDirectIndex() { return -1; }
+ /** Returns the value as a double */
@Override
- public Double getValue() { return value; }
+ public Double getValue() { return value.doubleValue(); }
+
+ /** Returns the value as a float */
+ public float getFloatValue() { return value.floatValue(); }
+
+ /** Returns the value as a double */
+ public double getDoubleValue() { return value.doubleValue(); }
@Override
public Double setValue(Double value) {
@@ -446,9 +453,11 @@ public interface Tensor {
/** Add a cell */
Builder cell(TensorAddress address, double value);
+ Builder cell(TensorAddress address, float value);
/** Add a cell */
Builder cell(double value, long ... labels);
+ Builder cell(float value, long ... labels);
/**
* Add a cell
@@ -459,6 +468,9 @@ public interface Tensor {
default Builder cell(Cell cell, double value) {
return cell(cell.getKey(), value);
}
+ default Builder cell(Cell cell, float value) {
+ return cell(cell.getKey(), value);
+ }
Tensor build();
@@ -484,6 +496,9 @@ public interface Tensor {
public Builder value(double cellValue) {
return tensorBuilder.cell(addressBuilder.build(), cellValue);
}
+ public Builder value(float cellValue) {
+ return tensorBuilder.cell(addressBuilder.build(), cellValue);
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index df78f3dfc3a..b1c7a2341c0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -143,6 +143,7 @@ public class TensorType {
}
private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) {
+ if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed
if (generalization.dimensions().size() != this.dimensions().size()) return false;
for (int i = 0; i < generalization.dimensions().size(); i++) {
Dimension thisDimension = this.dimensions().get(i);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 5072484567d..0cec09157fb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -38,7 +38,7 @@ public class DenseBinaryFormat implements BinaryFormat {
if ( ! ( tensor instanceof IndexedTensor))
throw new RuntimeException("The dense format is only supported for indexed tensors");
encodeDimensions(buffer, (IndexedTensor)tensor);
- encodeCells(buffer, tensor);
+ encodeCells(buffer, (IndexedTensor)tensor);
}
private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) {
@@ -49,18 +49,21 @@ public class DenseBinaryFormat implements BinaryFormat {
}
}
- private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
+ private void encodeCells(GrowableByteBuffer buffer, IndexedTensor tensor) {
switch (serializationValueType) {
- case DOUBLE: encodeCells(tensor, buffer::putDouble); break;
- case FLOAT: encodeCells(tensor, (i) -> buffer.putFloat(i.floatValue())); break;
+ case DOUBLE: encodeDoubleCells(tensor, buffer); break;
+ case FLOAT: encodeFloatCells(tensor, buffer); break;
}
}
- private void encodeCells(Tensor tensor, Consumer<Double> consumer) {
- Iterator<Double> i = tensor.valueIterator();
- while (i.hasNext()) {
- consumer.accept(i.next());
- }
+ private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putDouble(tensor.get(i));
+ }
+
+ private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putFloat(tensor.getFloat(i));
}
@Override
@@ -106,14 +109,19 @@ public class DenseBinaryFormat implements BinaryFormat {
private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
switch (serializationValueType) {
- case DOUBLE: decodeCells(sizes, builder, buffer::getDouble); break;
- case FLOAT: decodeCells(sizes, builder, () -> (double)buffer.getFloat()); break;
+ case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break;
+ case FLOAT: decodeFloatCells(sizes, builder, buffer); break;
}
}
- private void decodeCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, Supplier<Double> supplier) {
+ private void decodeDoubleCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.getDouble());
+ }
+
+ private void decodeFloatCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
for (long i = 0; i < sizes.totalSize(); i++)
- builder.cellByDirectIndex(i, supplier.get());
+ builder.cellByDirectIndex(i, buffer.getFloat());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 02d16e6f3e4..b01d171792c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -37,6 +37,17 @@ public class TensorTestCase {
}
@Test
+ public void testValueTypes() {
+ assertEquals(Tensor.from("tensor<double>(x[1]):{{x:0}:5}").getClass(), IndexedDoubleTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<double>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedDoubleTensor.class);
+
+ assertEquals(Tensor.from("tensor<float>(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedFloatTensor.class);
+ }
+
+ @Test
public void testParseError() {
try {
Tensor.from("--");