diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java | 34 |
1 files changed, 21 insertions, 13 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 5072484567d..0cec09157fb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -38,7 +38,7 @@ public class DenseBinaryFormat implements BinaryFormat { if ( ! ( tensor instanceof IndexedTensor)) throw new RuntimeException("The dense format is only supported for indexed tensors"); encodeDimensions(buffer, (IndexedTensor)tensor); - encodeCells(buffer, tensor); + encodeCells(buffer, (IndexedTensor)tensor); } private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) { @@ -49,18 +49,21 @@ public class DenseBinaryFormat implements BinaryFormat { } } - private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { + private void encodeCells(GrowableByteBuffer buffer, IndexedTensor tensor) { switch (serializationValueType) { - case DOUBLE: encodeCells(tensor, buffer::putDouble); break; - case FLOAT: encodeCells(tensor, (i) -> buffer.putFloat(i.floatValue())); break; + case DOUBLE: encodeDoubleCells(tensor, buffer); break; + case FLOAT: encodeFloatCells(tensor, buffer); break; } } - private void encodeCells(Tensor tensor, Consumer<Double> consumer) { - Iterator<Double> i = tensor.valueIterator(); - while (i.hasNext()) { - consumer.accept(i.next()); - } + private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.putDouble(tensor.get(i)); + } + + private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.putFloat(tensor.getFloat(i)); } @Override @@ -106,14 +109,19 @@ public class DenseBinaryFormat implements BinaryFormat { private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { switch (serializationValueType) { - case DOUBLE: decodeCells(sizes, builder, buffer::getDouble); break; - case FLOAT: decodeCells(sizes, builder, () -> (double)buffer.getFloat()); break; + case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break; + case FLOAT: decodeFloatCells(sizes, builder, buffer); break; } } - private void decodeCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, Supplier<Double> supplier) { + private void decodeDoubleCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) + builder.cellByDirectIndex(i, buffer.getDouble()); + } + + private void decodeFloatCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, supplier.get()); + builder.cellByDirectIndex(i, buffer.getFloat()); } } |