diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-18 09:29:31 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-18 09:29:31 +0100 |
commit | 1d5a6a3ae4e34d21220ec748591965522bb17eae (patch) | |
tree | 950ee7781acd5c9085109de15231eb0ac620c908 /vespajlib/src | |
parent | a8c35d35066f76f69a9254ee3957b3dd7aefb753 (diff) |
- Add sizeAsInt to allow for safe cast from long to int of the size of a tensor.
Diffstat (limited to 'vespajlib/src')
7 files changed, 37 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 548d39dd767..53f50fc4d02 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -22,6 +22,10 @@ class IndexedDoubleTensor extends IndexedTensor { return values.length; } + /** Once we can store more cells than an int we should drop this method. */ + @Override + public int sizeAsInt() { return values.length; } + @Override public double get(long valueIndex) { return values[(int)valueIndex]; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 26560a70ac4..3085ef1a843 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -18,9 +18,11 @@ class IndexedFloatTensor extends IndexedTensor { } @Override - public long size() { - return values.length; - } + public long size() { return values.length; } + + /** Once we can store more cells than an int we should drop this. */ + @Override + public int sizeAsInt() { return values.length; } @Override public double get(long valueIndex) { return getFloat(valueIndex); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index e196569b18f..e529c7f71d2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -31,6 +31,10 @@ public class MappedTensor implements Tensor { @Override public long size() { return cells.size(); } + /** Once we can store more cells than an int we should drop this. */ + @Override + public int sizeAsInt() { return cells.size(); } + @Override public double get(TensorAddress address) { return cells.getOrDefault(address, 0.0); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8a4179cdc1a..e44df06ed20 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -64,10 +64,25 @@ public interface Tensor { /** * Returns the number of cells in this. - * TODO Figure how to best return an int instead of a long - * An int is large enough, and java is far better at int base loops than long + * Allows for very large tensors, but if you only handle size in the int range + * prefer sizeAsInt(). **/ - long size(); + default long size() { + return sizeAsInt(); + } + + /** + * Safe way to get size as an int and detect when not possible. + * Prefer this over size() as + * @return size() as an int + */ + default int sizeAsInt() { + long sz = size(); + if (sz > Integer.MAX_VALUE) { + throw new IndexOutOfBoundsException("size = " + sz + ", which is too large to fit in an int"); + } + return (int) sz; + } /** Returns the value of a cell, or 0.0 if this cell does not exist */ double get(TensorAddress address); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 8cf88610599..5171cf1e472 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -131,7 +131,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions); // TODO cells.size() is most likely an overestimate, and might need a better heuristic // But the upside is larger than the downside. - Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size()); + Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index ca9527fd681..32e74c0f132 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -56,22 +56,22 @@ public class DenseBinaryFormat implements BinaryFormat { } private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.putDouble(tensor.get(i)); } private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.putFloat(tensor.getFloat(i)); } private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i))); } private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.put((byte) tensor.getFloat(i)); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index bdeb9add41a..3a117e41461 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -48,7 +48,7 @@ class SparseBinaryFormat implements BinaryFormat { } private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { - buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation + buffer.putInt1_4Bytes(tensor.sizeAsInt()); // XXX: Size truncation switch (serializationValueType) { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; |