summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java29
1 files changed, 26 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 0cec09157fb..1567c95c9fa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -7,10 +7,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import java.util.Iterator;
import java.util.Optional;
-import java.util.function.Consumer;
-import java.util.function.Supplier;
/**
* Implementation of a dense binary format for a tensor on the form:
@@ -53,6 +50,8 @@ public class DenseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: encodeDoubleCells(tensor, buffer); break;
case FLOAT: encodeFloatCells(tensor, buffer); break;
+ case BFLOAT16: encodeBFloat16Cells(tensor, buffer); break;
+ case INT8: encodeInt8Cells(tensor, buffer); break;
}
}
@@ -66,6 +65,16 @@ public class DenseBinaryFormat implements BinaryFormat {
buffer.putFloat(tensor.getFloat(i));
}
+ private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i)));
+ }
+
+ private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.put((byte) tensor.getFloat(i));
+ }
+
@Override
public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
TensorType type;
@@ -111,6 +120,8 @@ public class DenseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break;
case FLOAT: decodeFloatCells(sizes, builder, buffer); break;
+ case BFLOAT16: decodeBFloat16Cells(sizes, builder, buffer); break;
+ case INT8: decodeInt8Cells(sizes, builder, buffer); break;
}
}
@@ -124,4 +135,16 @@ public class DenseBinaryFormat implements BinaryFormat {
builder.cellByDirectIndex(i, buffer.getFloat());
}
+ private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++) {
+ builder.cellByDirectIndex(i, TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort()));
+ }
+ }
+
+ private void decodeInt8Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++) {
+ builder.cellByDirectIndex(i, (float) buffer.get());
+ }
+ }
+
}