diff options
16 files changed, 298 insertions, 15 deletions
diff --git a/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json b/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json index f6b535e071a..b7710eadf5d 100644 --- a/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json +++ b/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json @@ -20,4 +20,24 @@ {"tensor":{"type":"tensor<float>(x[10],y{})","cells":[]},"binary":["0x07010101790101780A00"]} {"tensor":{"type":"tensor<float>(x{},y[3])","cells":[{"address":{"x":"a","y":0},"value":11},{"address":{"x":"a","y":1},"value":12},{"address":{"x":"a","y":2},"value":13},{"address":{"x":"b","y":0},"value":21},{"address":{"x":"b","y":1},"value":22},{"address":{"x":"b","y":2},"value":23}]},"binary":["0x070101017801017903020161413000004140000041500000016241A8000041B0000041B80000","0x07010101780101790302016241A8000041B0000041B800000161413000004140000041500000"]} {"tensor":{"type":"tensor<float>(x[3],y{})","cells":[{"address":{"x":0,"y":"a"},"value":11},{"address":{"x":0,"y":"b"},"value":21},{"address":{"x":1,"y":"a"},"value":12},{"address":{"x":1,"y":"b"},"value":22},{"address":{"x":2,"y":"a"},"value":13},{"address":{"x":2,"y":"b"},"value":23}]},"binary":["0x070101017901017803020161413000004140000041500000016241A8000041B0000041B80000","0x07010101790101780302016241A8000041B0000041B800000161413000004140000041500000"]} -{"num_tests":22} +{"tensor":{"type":"tensor<bfloat16>(x[3])","cells":[{"address":{"x":0},"value":1},{"address":{"x":1},"value":2},{"address":{"x":2},"value":3}]},"binary":["0x0602010178033F8040004040","0x070200010178033F8040004040"]} +{"tensor":{"type":"tensor<bfloat16>(x[2],y[3])","cells":[{"address":{"x":0,"y":0},"value":11},{"address":{"x":0,"y":1},"value":12},{"address":{"x":0,"y":2},"value":13},{"address":{"x":1,"y":0},"value":21},{"address":{"x":1,"y":1},"value":22},{"address":{"x":1,"y":2},"value":23}]},"binary":["0x06020201780201790341304140415041A841B041B8","0x0702000201780201790341304140415041A841B041B8"]} +{"tensor":{"type":"tensor<bfloat16>(x{})","cells":[]},"binary":["0x050201017800","0x07020101780000"]} +{"tensor":{"type":"tensor<bfloat16>(x{})","cells":[{"address":{"x":"a"},"value":1},{"address":{"x":"b"},"value":2},{"address":{"x":"c"},"value":3}]},"binary":["0x05020101780301613F800162400001634040","0x0702010178000301613F800162400001634040","0x05020101780301613F800163404001624000","0x0702010178000301613F800163404001624000","0x0502010178030162400001613F8001634040","0x070201017800030162400001613F8001634040","0x050201017803016240000163404001613F80","0x07020101780003016240000163404001613F80","0x0502010178030163404001613F8001624000","0x070201017800030163404001613F8001624000","0x050201017803016340400162400001613F80","0x07020101780003016340400162400001613F80"]} +{"tensor":{"type":"tensor<bfloat16>(x{},y{})","cells":[]},"binary":["0x0502020178017900","0x070202017801790000"]} +{"tensor":{"type":"tensor<bfloat16>(x{},y{})","cells":[{"address":{"x":"bar","y":"a"},"value":21},{"address":{"x":"foo","y":"a"},"value":11}]},"binary":["0x050202017801790203666F6F0161413003626172016141A8","0x07020201780179000203666F6F0161413003626172016141A8","0x050202017801790203626172016141A803666F6F01614130","0x07020201780179000203626172016141A803666F6F01614130"]} +{"tensor":{"type":"tensor<bfloat16>(x{},y[10])","cells":[]},"binary":["0x07020101780101790A00"]} +{"tensor":{"type":"tensor<bfloat16>(x[10],y{})","cells":[]},"binary":["0x07020101790101780A00"]} +{"tensor":{"type":"tensor<bfloat16>(x{},y[3])","cells":[{"address":{"x":"a","y":0},"value":11},{"address":{"x":"a","y":1},"value":12},{"address":{"x":"a","y":2},"value":13},{"address":{"x":"b","y":0},"value":21},{"address":{"x":"b","y":1},"value":22},{"address":{"x":"b","y":2},"value":23}]},"binary":["0x070201017801017903020161413041404150016241A841B041B8","0x07020101780101790302016241A841B041B80161413041404150"]} +{"tensor":{"type":"tensor<bfloat16>(x[3],y{})","cells":[{"address":{"x":0,"y":"a"},"value":11},{"address":{"x":0,"y":"b"},"value":21},{"address":{"x":1,"y":"a"},"value":12},{"address":{"x":1,"y":"b"},"value":22},{"address":{"x":2,"y":"a"},"value":13},{"address":{"x":2,"y":"b"},"value":23}]},"binary":["0x070201017901017803020161413041404150016241A841B041B8","0x07020101790101780302016241A841B041B80161413041404150"]} +{"tensor":{"type":"tensor<int8>(x[3])","cells":[{"address":{"x":0},"value":1},{"address":{"x":1},"value":2},{"address":{"x":2},"value":3}]},"binary":["0x060301017803010203","0x07030001017803010203"]} +{"tensor":{"type":"tensor<int8>(x[2],y[3])","cells":[{"address":{"x":0,"y":0},"value":11},{"address":{"x":0,"y":1},"value":12},{"address":{"x":0,"y":2},"value":13},{"address":{"x":1,"y":0},"value":21},{"address":{"x":1,"y":1},"value":22},{"address":{"x":1,"y":2},"value":23}]},"binary":["0x0603020178020179030B0C0D151617","0x070300020178020179030B0C0D151617"]} +{"tensor":{"type":"tensor<int8>(x{})","cells":[]},"binary":["0x050301017800","0x07030101780000"]} +{"tensor":{"type":"tensor<int8>(x{})","cells":[{"address":{"x":"a"},"value":1},{"address":{"x":"b"},"value":2},{"address":{"x":"c"},"value":3}]},"binary":["0x050301017803016101016202016303","0x07030101780003016101016202016303","0x050301017803016101016303016202","0x07030101780003016101016303016202","0x050301017803016202016101016303","0x07030101780003016202016101016303","0x050301017803016202016303016101","0x07030101780003016202016303016101","0x050301017803016303016101016202","0x07030101780003016303016101016202","0x050301017803016303016202016101","0x07030101780003016303016202016101"]} +{"tensor":{"type":"tensor<int8>(x{},y{})","cells":[]},"binary":["0x0503020178017900","0x070302017801790000"]} +{"tensor":{"type":"tensor<int8>(x{},y{})","cells":[{"address":{"x":"bar","y":"a"},"value":21},{"address":{"x":"foo","y":"a"},"value":11}]},"binary":["0x050302017801790203666F6F01610B03626172016115","0x07030201780179000203666F6F01610B03626172016115","0x05030201780179020362617201611503666F6F01610B","0x0703020178017900020362617201611503666F6F01610B"]} +{"tensor":{"type":"tensor<int8>(x{},y[10])","cells":[]},"binary":["0x07030101780101790A00"]} +{"tensor":{"type":"tensor<int8>(x[10],y{})","cells":[]},"binary":["0x07030101790101780A00"]} +{"tensor":{"type":"tensor<int8>(x{},y[3])","cells":[{"address":{"x":"a","y":0},"value":11},{"address":{"x":"a","y":1},"value":12},{"address":{"x":"a","y":2},"value":13},{"address":{"x":"b","y":0},"value":21},{"address":{"x":"b","y":1},"value":22},{"address":{"x":"b","y":2},"value":23}]},"binary":["0x0703010178010179030201610B0C0D0162151617","0x07030101780101790302016215161701610B0C0D"]} +{"tensor":{"type":"tensor<int8>(x[3],y{})","cells":[{"address":{"x":0,"y":"a"},"value":11},{"address":{"x":0,"y":"b"},"value":21},{"address":{"x":1,"y":"a"},"value":12},{"address":{"x":1,"y":"b"},"value":22},{"address":{"x":2,"y":"a"},"value":13},{"address":{"x":2,"y":"b"},"value":23}]},"binary":["0x0703010179010178030201610B0C0D0162151617","0x07030101790101780302016215161701610B0C0D"]} +{"num_tests":42} diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index e51569da988..9ad2c55f7e3 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1411,7 +1411,9 @@ ], "fields": [ "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE", - "public static final enum com.yahoo.tensor.TensorType$Value FLOAT" + "public static final enum com.yahoo.tensor.TensorType$Value FLOAT", + "public static final enum com.yahoo.tensor.TensorType$Value INT8", + "public static final enum com.yahoo.tensor.TensorType$Value BFLOAT16" ] }, "com.yahoo.tensor.TensorType": { @@ -3463,4 +3465,4 @@ ], "fields": [] } -} +}
\ No newline at end of file diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index dc17c657db9..9f3d7c01c6b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -225,6 +225,10 @@ public abstract class IndexedTensor implements Tensor { b.append(tensor.get(index)); else if (tensor.type().valueType() == TensorType.Value.FLOAT) b.append(tensor.getFloat(index)); + else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) + b.append(tensor.getFloat(index)); + else if (tensor.type().valueType() == TensorType.Value.INT8) + b.append(tensor.getFloat(index)); else throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); @@ -295,6 +299,10 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + else if (type.valueType() == TensorType.Value.BFLOAT16) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + else if (type.valueType() == TensorType.Value.INT8) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); else @@ -315,6 +323,10 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + else if (type.valueType() == TensorType.Value.BFLOAT16) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.INT8) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); else @@ -335,6 +347,10 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.BFLOAT16) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.INT8) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); else diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index f608aead347..be5f4143f54 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -533,6 +533,10 @@ public class MixedTensor implements Tensor { b.append(getDouble(subspaceIndex, index, tensor)); else if (tensor.type().valueType() == TensorType.Value.FLOAT) b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats + else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) + b.append(getDouble(subspaceIndex, index, tensor)); + else if (tensor.type().valueType() == TensorType.Value.INT8) + b.append(getDouble(subspaceIndex, index, tensor)); else throw new IllegalStateException("Unexpected value type " + type.valueType()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index becec1a4493..126fee878ab 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -201,6 +201,10 @@ class TensorParser { return Double.parseDouble(cellValueString); else if (cellValueType == TensorType.Value.FLOAT) return Float.parseFloat(cellValueString); + else if (cellValueType == TensorType.Value.BFLOAT16) + return Float.parseFloat(cellValueString); + else if (cellValueType == TensorType.Value.INT8) + return Float.parseFloat(cellValueString); else throw new IllegalArgumentException(cellValueType + " is not supported"); } catch (NumberFormatException e) { @@ -291,6 +295,10 @@ class TensorParser { builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); else if (builder.type().valueType() == TensorType.Value.FLOAT) builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); + else if (builder.type().valueType() == TensorType.Value.BFLOAT16) + builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); + else if (builder.type().valueType() == TensorType.Value.INT8) + builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); } } @@ -355,6 +363,10 @@ class TensorParser { builder.cell(address, (Double)number); else if (builder.type().valueType() == TensorType.Value.FLOAT) builder.cell(address, (Float)number); + else if (builder.type().valueType() == TensorType.Value.BFLOAT16) + builder.cell(address, (Float)number); + else if (builder.type().valueType() == TensorType.Value.INT8) + builder.cell(address, (Float)number); } } @@ -392,6 +404,10 @@ class TensorParser { builder.cell(address, Double.parseDouble(cellValueString)); else if (cellValueType == TensorType.Value.FLOAT) builder.cell(address, Float.parseFloat(cellValueString)); + else if (cellValueType == TensorType.Value.BFLOAT16) + builder.cell(address, Float.parseFloat(cellValueString)); + else if (cellValueType == TensorType.Value.INT8) + builder.cell(address, Float.parseFloat(cellValueString)); else throw new IllegalArgumentException(cellValueType + " is not supported"); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 236e9d31c39..d7cf5bffcfe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -33,7 +33,7 @@ public class TensorType { public enum Value { // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below - DOUBLE("double"), FLOAT("float"); + DOUBLE("double"), FLOAT("float"), INT8("int8"), BFLOAT16("bfloat16"); private final String id; @@ -59,6 +59,9 @@ public class TensorType { public static Value largestOf(Value value1, Value value2) { if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE; + if (value1 == FLOAT || value2 == FLOAT) return FLOAT; + if (value1 == BFLOAT16 || value2 == BFLOAT16) return FLOAT; + if (value1 == INT8 || value2 == INT8) return FLOAT; return FLOAT; } @@ -69,8 +72,10 @@ public class TensorType { switch (valueTypeString) { case "double" : return Value.DOUBLE; case "float" : return Value.FLOAT; - default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + - " but was '" + valueTypeString + "'"); + case "bfloat16" : return Value.BFLOAT16; + case "int8" : return Value.INT8; + default : throw new IllegalArgumentException("Value type must be either 'double', 'float', " + + "'bfloat16', or 'int8' but was '" + valueTypeString + "'"); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 0cec09157fb..edb68025d45 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -53,6 +53,8 @@ public class DenseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: encodeDoubleCells(tensor, buffer); break; case FLOAT: encodeFloatCells(tensor, buffer); break; + case BFLOAT16: encodeBFloat16Cells(tensor, buffer); break; + case INT8: encodeInt8Cells(tensor, buffer); break; } } @@ -66,6 +68,16 @@ public class DenseBinaryFormat implements BinaryFormat { buffer.putFloat(tensor.getFloat(i)); } + private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.putShort((short)(Float.floatToRawIntBits(tensor.getFloat(i)) >>> 16)); + } + + private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.put((byte) tensor.getFloat(i)); + } + @Override public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) { TensorType type; @@ -111,6 +123,8 @@ public class DenseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break; case FLOAT: decodeFloatCells(sizes, builder, buffer); break; + case BFLOAT16: decodeBFloat16Cells(sizes, builder, buffer); break; + case INT8: decodeInt8Cells(sizes, builder, buffer); break; } } @@ -124,4 +138,16 @@ public class DenseBinaryFormat implements BinaryFormat { builder.cellByDirectIndex(i, buffer.getFloat()); } + private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) { + builder.cellByDirectIndex(i, Float.intBitsToFloat(buffer.getShort() << 16)); + } + } + + private void decodeInt8Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) { + builder.cellByDirectIndex(i, (float) buffer.get()); + } + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index bc247e5561f..7d500614caa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -64,6 +64,9 @@ class MixedBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; + case BFLOAT16: encodeCells(buffer, tensor, (val) -> + buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break; + case INT8: encodeCells(buffer, tensor, (val) -> buffer.put(((byte)val.floatValue()))); break; } } @@ -127,6 +130,9 @@ class MixedBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; + case BFLOAT16: decodeCells(buffer, builder, type, () -> + (double)Float.intBitsToFloat(buffer.getShort() << 16)); break; + case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index cd671f824fa..160cf660e1b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -52,6 +52,9 @@ class SparseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; + case BFLOAT16: encodeCells(buffer, tensor, (val) -> + buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break; + case INT8: encodeCells(buffer, tensor, (val) -> buffer.put((byte)(val.floatValue()))); break; } } @@ -102,6 +105,9 @@ class SparseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; + case BFLOAT16: decodeCells(buffer, builder, type, () -> + (double)Float.intBitsToFloat(buffer.getShort() << 16)); break; + case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 5c47572c779..cddb283489c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -29,6 +29,8 @@ public class TypedBinaryFormat { private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing private static final int FLOAT_VALUE_TYPE = 1; + private static final int BFLOAT16_VALUE_TYPE = 2; + private static final int INT8_VALUE_TYPE = 3; public static byte[] encode(Tensor tensor) { GrowableByteBuffer buffer = new GrowableByteBuffer(); @@ -113,6 +115,8 @@ public class TypedBinaryFormat { switch (valueType) { case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break; case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break; + case BFLOAT16: buffer.putInt1_4Bytes(BFLOAT16_VALUE_TYPE); break; + case INT8: buffer.putInt1_4Bytes(INT8_VALUE_TYPE); break; default: throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType); } @@ -123,8 +127,11 @@ public class TypedBinaryFormat { switch (valueType) { case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE; case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT; + case BFLOAT16_VALUE_TYPE: return TensorType.Value.BFLOAT16; + case INT8_VALUE_TYPE: return TensorType.Value.INT8; } - throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal."); + throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. " + + "Only 0(double), 1(float), 2(bfloat16), or 3(int8) is legal."); } private static byte[] asByteArray(GrowableByteBuffer buffer) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 5bd1bbdba37..b47c0873535 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -50,6 +50,35 @@ public class TensorTestCase { assertEquals(Tensor.from("tensor<float>(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class); assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[1])")).cell(5.0, 0).build().getClass(), IndexedFloatTensor.class); + + assertEquals(Tensor.from("tensor<bfloat16>(x[1]):[5]").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + + assertEquals(Tensor.from("tensor<int8>(x[1]):[5]").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + } + + private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) { + Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }"); + Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }"); + assertEquals(valueType, t1.multiply(t2).type().valueType()); + assertEquals(valueType, t2.multiply(t1).type().valueType()); + } + + @Test + public void testValueTypeResolving() { + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "float"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16"); + assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8"); } @Test diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index a547f941d8e..caa125dfef7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -96,8 +96,12 @@ public class TensorTypeTestCase { assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])"); assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])"); + assertValueType(TensorType.Value.BFLOAT16, "tensor<bfloat16>(x[])"); + assertValueType(TensorType.Value.INT8, "tensor<int8>(x[])"); assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString()); assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString()); + assertEquals("tensor<bfloat16>(x[])", TensorType.fromSpec("tensor<bfloat16>(x[])").toString()); + assertEquals("tensor<int8>(x[])", TensorType.fromSpec("tensor<int8>(x[])").toString()); } private static void assertTensorType(String typeSpec) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 5d1bc7b0c3f..3c79b0c769c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -41,7 +41,7 @@ public class DenseBinaryFormatTestCase { } @Test - public void requireThatDefaultSerializationFormatDoNotChange() { + public void requireThatDefaultSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[]{2, // binary format type 2, // dimension count 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size @@ -54,7 +54,7 @@ public class DenseBinaryFormatTestCase { } @Test - public void requireThatFloatSerializationFormatDoNotChange() { + public void requireThatFloatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[]{6, // binary format type 1, // float type 2, // dimension count @@ -68,9 +68,44 @@ public class DenseBinaryFormatTestCase { } @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[]{6, // binary format type + 2, // bfloat16 type + 2, // dimension count + 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size + 1, (byte) 'z', 1, // dimension z with size + 64, 0, // value 1 + 64, 64, // value 2 + }; + Tensor tensor = Tensor.from("tensor<bfloat16>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[]{6, // binary format type + 3, // int8 type + 2, // dimension count + 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size + 1, (byte) 'z', 1, // dimension z with size + 2, // value 1 + 3, // value 2 + }; + Tensor tensor = Tensor.from("tensor<int8>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test public void testSerializationOfDifferentValueTypes() { + assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<bfloat16>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<int8>(x[],y[]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); + assertSerialization("tensor<double>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); + assertSerialization("tensor<float>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); + assertSerialization("tensor<bfloat16>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); + assertSerialization("tensor<int8>(x[2],y[2]):[2, 3, 4, 5]"); } private void assertSerialization(String tensorString) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 81de8a9db4c..3ca20661587 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -134,4 +134,19 @@ public class JsonFormatTestCase { } } + private void assertEncodeDecode(Tensor tensor) { + Tensor decoded = JsonFormat.decode(tensor.type(), JsonFormat.encodeWithType(tensor)); + assertEquals(tensor, decoded); + assertEquals(tensor.type(), decoded.type()); + } + + @Test + public void testTensorCellTypes() { + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor<double>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor<float>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor<bfloat16>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor<int8>(x[2],y[2]):[2,3,5,8]")); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java index 69ef4922d8d..e9f8c81f21b 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; +import java.util.Arrays; import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -78,9 +79,70 @@ public class MixedBinaryFormatTestCase { } @Test + public void requireThatDefaultSerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {3, // binary format type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0 + 2, (byte)'c', (byte)'d', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatFloatSerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {7, // binary format type + 1, // float type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 64, 0, 0, 0, // cell 0 + 2, (byte)'c', (byte)'d', 64, 64, 0, 0}; // cell 1 + Tensor tensor = Tensor.from("tensor<float>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {7, // binary format type + 2, // bfloat16 type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 64, 0, // cell 0 + 2, (byte)'c', (byte)'d', 64, 64}; // cell 1 + Tensor tensor = Tensor.from("tensor<bfloat16>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {7, // binary format type + 3, // int8 type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 2, // cell 0 + 2, (byte)'c', (byte)'d', 3}; // cell 1 + Tensor tensor = Tensor.from("tensor<int8>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test public void testSerializationOfDifferentValueTypes() { assertSerialization("tensor<double>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor<float>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<bfloat16>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<int8>(x{},y[2]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); } private void assertSerialization(String tensorString) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 50b71024ddf..2a622b73513 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -55,19 +55,19 @@ public class SparseBinaryFormatTestCase { } @Test - public void requireThatSerializationFormatDoNotChange() { + public void requireThatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[] {1, // binary format type 2, // num dimensions 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions 2, // num cells, 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1 - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test - public void requireThatFloatSerializationFormatDoNotChange() { + public void requireThatFloatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[] { 5, // binary format type 1, // float type @@ -76,14 +76,44 @@ public class SparseBinaryFormatTestCase { 2, // num cells, 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1 - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Tensor tensor = Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] { + 5, // binary format type + 2, // bfloat16 type + 2, // num dimensions + 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions + 2, // num cells, + 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, // cell 0 + 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64}; // cell 1 + Tensor tensor = Tensor.from("tensor<bfloat16>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] { + 5, // binary format type + 3, // int8 type + 2, // num dimensions + 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions + 2, // num cells, + 2, (byte)'a', (byte)'b', 1, (byte)'e', 2, // cell 0 + 2, (byte)'c', (byte)'d', 1, (byte)'e', 3}; // cell 1 + Tensor tensor = Tensor.from("tensor<int8>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test public void testSerializationOfDifferentValueTypes() { assertSerialization("tensor<double>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor<float>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<bfloat16>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<int8>(x{},y{}):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); } private void assertSerialization(String tensorString) { |