aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-18 09:29:31 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-18 09:29:31 +0100
commit1d5a6a3ae4e34d21220ec748591965522bb17eae (patch)
tree950ee7781acd5c9085109de15231eb0ac620c908 /vespajlib/src/main
parenta8c35d35066f76f69a9254ee3957b3dd7aefb753 (diff)
- Add sizeAsInt to allow for safe cast from long to int of the size of a tensor.
Diffstat (limited to 'vespajlib/src/main')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java2
7 files changed, 37 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
index 548d39dd767..53f50fc4d02 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
@@ -22,6 +22,10 @@ class IndexedDoubleTensor extends IndexedTensor {
return values.length;
}
+ /** Once we can store more cells than an int we should drop this method. */
+ @Override
+ public int sizeAsInt() { return values.length; }
+
@Override
public double get(long valueIndex) { return values[(int)valueIndex]; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
index 26560a70ac4..3085ef1a843 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
@@ -18,9 +18,11 @@ class IndexedFloatTensor extends IndexedTensor {
}
@Override
- public long size() {
- return values.length;
- }
+ public long size() { return values.length; }
+
+ /** Once we can store more cells than an int we should drop this. */
+ @Override
+ public int sizeAsInt() { return values.length; }
@Override
public double get(long valueIndex) { return getFloat(valueIndex); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index e196569b18f..e529c7f71d2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -31,6 +31,10 @@ public class MappedTensor implements Tensor {
@Override
public long size() { return cells.size(); }
+ /** Once we can store more cells than an int we should drop this. */
+ @Override
+ public int sizeAsInt() { return cells.size(); }
+
@Override
public double get(TensorAddress address) { return cells.getOrDefault(address, 0.0); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 8a4179cdc1a..e44df06ed20 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -64,10 +64,25 @@ public interface Tensor {
/**
* Returns the number of cells in this.
- * TODO Figure how to best return an int instead of a long
- * An int is large enough, and java is far better at int base loops than long
+ * Allows for very large tensors, but if you only handle size in the int range
+ * prefer sizeAsInt().
**/
- long size();
+ default long size() {
+ return sizeAsInt();
+ }
+
+ /**
+ * Safe way to get size as an int and detect when not possible.
+ * Prefer this over size() as
+ * @return size() as an int
+ */
+ default int sizeAsInt() {
+ long sz = size();
+ if (sz > Integer.MAX_VALUE) {
+ throw new IndexOutOfBoundsException("size = " + sz + ", which is too large to fit in an int");
+ }
+ return (int) sz;
+ }
/** Returns the value of a cell, or 0.0 if this cell does not exist */
double get(TensorAddress address);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 8cf88610599..5171cf1e472 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -131,7 +131,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions);
// TODO cells.size() is most likely an overestimate, and might need a better heuristic
// But the upside is larger than the downside.
- Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size());
+ Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index ca9527fd681..32e74c0f132 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -56,22 +56,22 @@ public class DenseBinaryFormat implements BinaryFormat {
}
private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
- for (int i = 0; i < tensor.size(); i++)
+ for (int i = 0; i < tensor.sizeAsInt(); i++)
buffer.putDouble(tensor.get(i));
}
private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
- for (int i = 0; i < tensor.size(); i++)
+ for (int i = 0; i < tensor.sizeAsInt(); i++)
buffer.putFloat(tensor.getFloat(i));
}
private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
- for (int i = 0; i < tensor.size(); i++)
+ for (int i = 0; i < tensor.sizeAsInt(); i++)
buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i)));
}
private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
- for (int i = 0; i < tensor.size(); i++)
+ for (int i = 0; i < tensor.sizeAsInt(); i++)
buffer.put((byte) tensor.getFloat(i));
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index bdeb9add41a..3a117e41461 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -48,7 +48,7 @@ class SparseBinaryFormat implements BinaryFormat {
}
private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
- buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation
+ buffer.putInt1_4Bytes(tensor.sizeAsInt()); // XXX: Size truncation
switch (serializationValueType) {
case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;