summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java34
1 files changed, 21 insertions, 13 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 5072484567d..0cec09157fb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -38,7 +38,7 @@ public class DenseBinaryFormat implements BinaryFormat {
if ( ! ( tensor instanceof IndexedTensor))
throw new RuntimeException("The dense format is only supported for indexed tensors");
encodeDimensions(buffer, (IndexedTensor)tensor);
- encodeCells(buffer, tensor);
+ encodeCells(buffer, (IndexedTensor)tensor);
}
private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) {
@@ -49,18 +49,21 @@ public class DenseBinaryFormat implements BinaryFormat {
}
}
- private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
+ private void encodeCells(GrowableByteBuffer buffer, IndexedTensor tensor) {
switch (serializationValueType) {
- case DOUBLE: encodeCells(tensor, buffer::putDouble); break;
- case FLOAT: encodeCells(tensor, (i) -> buffer.putFloat(i.floatValue())); break;
+ case DOUBLE: encodeDoubleCells(tensor, buffer); break;
+ case FLOAT: encodeFloatCells(tensor, buffer); break;
}
}
- private void encodeCells(Tensor tensor, Consumer<Double> consumer) {
- Iterator<Double> i = tensor.valueIterator();
- while (i.hasNext()) {
- consumer.accept(i.next());
- }
+ private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putDouble(tensor.get(i));
+ }
+
+ private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putFloat(tensor.getFloat(i));
}
@Override
@@ -106,14 +109,19 @@ public class DenseBinaryFormat implements BinaryFormat {
private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
switch (serializationValueType) {
- case DOUBLE: decodeCells(sizes, builder, buffer::getDouble); break;
- case FLOAT: decodeCells(sizes, builder, () -> (double)buffer.getFloat()); break;
+ case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break;
+ case FLOAT: decodeFloatCells(sizes, builder, buffer); break;
}
}
- private void decodeCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, Supplier<Double> supplier) {
+ private void decodeDoubleCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.getDouble());
+ }
+
+ private void decodeFloatCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
for (long i = 0; i < sizes.totalSize(); i++)
- builder.cellByDirectIndex(i, supplier.get());
+ builder.cellByDirectIndex(i, buffer.getFloat());
}
}