diff options
author | Lester Solbakken <lesters@oath.com> | 2021-04-08 11:24:52 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-04-08 11:24:52 +0200 |
commit | 049e9a325c8142958909d0464da12a56e5a8f638 (patch) | |
tree | 31d857ec4a5ad3415464e480ae473c39224623b2 /vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java | |
parent | bccd68f8f9a7eb0830d136f8b034ae4f40cc819c (diff) |
Add bfloat16 and int8 tensor cell types in Java
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 0cec09157fb..edb68025d45 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -53,6 +53,8 @@ public class DenseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: encodeDoubleCells(tensor, buffer); break; case FLOAT: encodeFloatCells(tensor, buffer); break; + case BFLOAT16: encodeBFloat16Cells(tensor, buffer); break; + case INT8: encodeInt8Cells(tensor, buffer); break; } } @@ -66,6 +68,16 @@ public class DenseBinaryFormat implements BinaryFormat { buffer.putFloat(tensor.getFloat(i)); } + private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.putShort((short)(Float.floatToRawIntBits(tensor.getFloat(i)) >>> 16)); + } + + private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.put((byte) tensor.getFloat(i)); + } + @Override public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) { TensorType type; @@ -111,6 +123,8 @@ public class DenseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break; case FLOAT: decodeFloatCells(sizes, builder, buffer); break; + case BFLOAT16: decodeBFloat16Cells(sizes, builder, buffer); break; + case INT8: decodeInt8Cells(sizes, builder, buffer); break; } } @@ -124,4 +138,16 @@ public class DenseBinaryFormat implements BinaryFormat { builder.cellByDirectIndex(i, buffer.getFloat()); } + private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) { + builder.cellByDirectIndex(i, Float.intBitsToFloat(buffer.getShort() << 16)); + } + } + + private void decodeInt8Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) { + builder.cellByDirectIndex(i, (float) buffer.get()); + } + } + } |