summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
committerLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
commit049e9a325c8142958909d0464da12a56e5a8f638 (patch)
tree31d857ec4a5ad3415464e480ae473c39224623b2 /vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
parentbccd68f8f9a7eb0830d136f8b034ae4f40cc819c (diff)
Add bfloat16 and int8 tensor cell types in Java
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java26
1 files changed, 26 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 0cec09157fb..edb68025d45 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -53,6 +53,8 @@ public class DenseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: encodeDoubleCells(tensor, buffer); break;
case FLOAT: encodeFloatCells(tensor, buffer); break;
+ case BFLOAT16: encodeBFloat16Cells(tensor, buffer); break;
+ case INT8: encodeInt8Cells(tensor, buffer); break;
}
}
@@ -66,6 +68,16 @@ public class DenseBinaryFormat implements BinaryFormat {
buffer.putFloat(tensor.getFloat(i));
}
+ private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putShort((short)(Float.floatToRawIntBits(tensor.getFloat(i)) >>> 16));
+ }
+
+ private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.put((byte) tensor.getFloat(i));
+ }
+
@Override
public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
TensorType type;
@@ -111,6 +123,8 @@ public class DenseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break;
case FLOAT: decodeFloatCells(sizes, builder, buffer); break;
+ case BFLOAT16: decodeBFloat16Cells(sizes, builder, buffer); break;
+ case INT8: decodeInt8Cells(sizes, builder, buffer); break;
}
}
@@ -124,4 +138,16 @@ public class DenseBinaryFormat implements BinaryFormat {
builder.cellByDirectIndex(i, buffer.getFloat());
}
+ private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++) {
+ builder.cellByDirectIndex(i, Float.intBitsToFloat(buffer.getShort() << 16));
+ }
+ }
+
+ private void decodeInt8Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++) {
+ builder.cellByDirectIndex(i, (float) buffer.get());
+ }
+ }
+
}