diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 0cec09157fb..1567c95c9fa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -7,10 +7,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.Iterator; import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Supplier; /** * Implementation of a dense binary format for a tensor on the form: @@ -53,6 +50,8 @@ public class DenseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: encodeDoubleCells(tensor, buffer); break; case FLOAT: encodeFloatCells(tensor, buffer); break; + case BFLOAT16: encodeBFloat16Cells(tensor, buffer); break; + case INT8: encodeInt8Cells(tensor, buffer); break; } } @@ -66,6 +65,16 @@ public class DenseBinaryFormat implements BinaryFormat { buffer.putFloat(tensor.getFloat(i)); } + private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i))); + } + + private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.put((byte) tensor.getFloat(i)); + } + @Override public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) { TensorType type; @@ -111,6 +120,8 @@ public class DenseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break; case FLOAT: decodeFloatCells(sizes, builder, buffer); break; + case BFLOAT16: decodeBFloat16Cells(sizes, builder, buffer); break; + case INT8: decodeInt8Cells(sizes, builder, buffer); break; } } @@ -124,4 +135,16 @@ public class DenseBinaryFormat implements BinaryFormat { builder.cellByDirectIndex(i, buffer.getFloat()); } + private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) { + builder.cellByDirectIndex(i, TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); + } + } + + private void decodeInt8Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) { + builder.cellByDirectIndex(i, (float) buffer.get()); + } + } + } |