From 049e9a325c8142958909d0464da12a56e5a8f638 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 8 Apr 2021 11:24:52 +0200 Subject: Add bfloat16 and int8 tensor cell types in Java --- .../test_spec.json | 22 +++++++- vespajlib/abi-spec.json | 6 ++- .../main/java/com/yahoo/tensor/IndexedTensor.java | 16 ++++++ .../main/java/com/yahoo/tensor/MixedTensor.java | 4 ++ .../main/java/com/yahoo/tensor/TensorParser.java | 16 ++++++ .../src/main/java/com/yahoo/tensor/TensorType.java | 11 ++-- .../tensor/serialization/DenseBinaryFormat.java | 26 +++++++++ .../tensor/serialization/MixedBinaryFormat.java | 6 +++ .../tensor/serialization/SparseBinaryFormat.java | 6 +++ .../tensor/serialization/TypedBinaryFormat.java | 9 +++- .../test/java/com/yahoo/tensor/TensorTestCase.java | 29 ++++++++++ .../java/com/yahoo/tensor/TensorTypeTestCase.java | 4 ++ .../serialization/DenseBinaryFormatTestCase.java | 39 +++++++++++++- .../tensor/serialization/JsonFormatTestCase.java | 15 ++++++ .../serialization/MixedBinaryFormatTestCase.java | 62 ++++++++++++++++++++++ .../serialization/SparseBinaryFormatTestCase.java | 42 ++++++++++++--- 16 files changed, 298 insertions(+), 15 deletions(-) diff --git a/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json b/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json index f6b535e071a..b7710eadf5d 100644 --- a/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json +++ b/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json @@ -20,4 +20,24 @@ {"tensor":{"type":"tensor(x[10],y{})","cells":[]},"binary":["0x07010101790101780A00"]} {"tensor":{"type":"tensor(x{},y[3])","cells":[{"address":{"x":"a","y":0},"value":11},{"address":{"x":"a","y":1},"value":12},{"address":{"x":"a","y":2},"value":13},{"address":{"x":"b","y":0},"value":21},{"address":{"x":"b","y":1},"value":22},{"address":{"x":"b","y":2},"value":23}]},"binary":["0x070101017801017903020161413000004140000041500000016241A8000041B0000041B80000","0x07010101780101790302016241A8000041B0000041B800000161413000004140000041500000"]} {"tensor":{"type":"tensor(x[3],y{})","cells":[{"address":{"x":0,"y":"a"},"value":11},{"address":{"x":0,"y":"b"},"value":21},{"address":{"x":1,"y":"a"},"value":12},{"address":{"x":1,"y":"b"},"value":22},{"address":{"x":2,"y":"a"},"value":13},{"address":{"x":2,"y":"b"},"value":23}]},"binary":["0x070101017901017803020161413000004140000041500000016241A8000041B0000041B80000","0x07010101790101780302016241A8000041B0000041B800000161413000004140000041500000"]} -{"num_tests":22} +{"tensor":{"type":"tensor(x[3])","cells":[{"address":{"x":0},"value":1},{"address":{"x":1},"value":2},{"address":{"x":2},"value":3}]},"binary":["0x0602010178033F8040004040","0x070200010178033F8040004040"]} +{"tensor":{"type":"tensor(x[2],y[3])","cells":[{"address":{"x":0,"y":0},"value":11},{"address":{"x":0,"y":1},"value":12},{"address":{"x":0,"y":2},"value":13},{"address":{"x":1,"y":0},"value":21},{"address":{"x":1,"y":1},"value":22},{"address":{"x":1,"y":2},"value":23}]},"binary":["0x06020201780201790341304140415041A841B041B8","0x0702000201780201790341304140415041A841B041B8"]} +{"tensor":{"type":"tensor(x{})","cells":[]},"binary":["0x050201017800","0x07020101780000"]} +{"tensor":{"type":"tensor(x{})","cells":[{"address":{"x":"a"},"value":1},{"address":{"x":"b"},"value":2},{"address":{"x":"c"},"value":3}]},"binary":["0x05020101780301613F800162400001634040","0x0702010178000301613F800162400001634040","0x05020101780301613F800163404001624000","0x0702010178000301613F800163404001624000","0x0502010178030162400001613F8001634040","0x070201017800030162400001613F8001634040","0x050201017803016240000163404001613F80","0x07020101780003016240000163404001613F80","0x0502010178030163404001613F8001624000","0x070201017800030163404001613F8001624000","0x050201017803016340400162400001613F80","0x07020101780003016340400162400001613F80"]} +{"tensor":{"type":"tensor(x{},y{})","cells":[]},"binary":["0x0502020178017900","0x070202017801790000"]} +{"tensor":{"type":"tensor(x{},y{})","cells":[{"address":{"x":"bar","y":"a"},"value":21},{"address":{"x":"foo","y":"a"},"value":11}]},"binary":["0x050202017801790203666F6F0161413003626172016141A8","0x07020201780179000203666F6F0161413003626172016141A8","0x050202017801790203626172016141A803666F6F01614130","0x07020201780179000203626172016141A803666F6F01614130"]} +{"tensor":{"type":"tensor(x{},y[10])","cells":[]},"binary":["0x07020101780101790A00"]} +{"tensor":{"type":"tensor(x[10],y{})","cells":[]},"binary":["0x07020101790101780A00"]} +{"tensor":{"type":"tensor(x{},y[3])","cells":[{"address":{"x":"a","y":0},"value":11},{"address":{"x":"a","y":1},"value":12},{"address":{"x":"a","y":2},"value":13},{"address":{"x":"b","y":0},"value":21},{"address":{"x":"b","y":1},"value":22},{"address":{"x":"b","y":2},"value":23}]},"binary":["0x070201017801017903020161413041404150016241A841B041B8","0x07020101780101790302016241A841B041B80161413041404150"]} +{"tensor":{"type":"tensor(x[3],y{})","cells":[{"address":{"x":0,"y":"a"},"value":11},{"address":{"x":0,"y":"b"},"value":21},{"address":{"x":1,"y":"a"},"value":12},{"address":{"x":1,"y":"b"},"value":22},{"address":{"x":2,"y":"a"},"value":13},{"address":{"x":2,"y":"b"},"value":23}]},"binary":["0x070201017901017803020161413041404150016241A841B041B8","0x07020101790101780302016241A841B041B80161413041404150"]} +{"tensor":{"type":"tensor(x[3])","cells":[{"address":{"x":0},"value":1},{"address":{"x":1},"value":2},{"address":{"x":2},"value":3}]},"binary":["0x060301017803010203","0x07030001017803010203"]} +{"tensor":{"type":"tensor(x[2],y[3])","cells":[{"address":{"x":0,"y":0},"value":11},{"address":{"x":0,"y":1},"value":12},{"address":{"x":0,"y":2},"value":13},{"address":{"x":1,"y":0},"value":21},{"address":{"x":1,"y":1},"value":22},{"address":{"x":1,"y":2},"value":23}]},"binary":["0x0603020178020179030B0C0D151617","0x070300020178020179030B0C0D151617"]} +{"tensor":{"type":"tensor(x{})","cells":[]},"binary":["0x050301017800","0x07030101780000"]} +{"tensor":{"type":"tensor(x{})","cells":[{"address":{"x":"a"},"value":1},{"address":{"x":"b"},"value":2},{"address":{"x":"c"},"value":3}]},"binary":["0x050301017803016101016202016303","0x07030101780003016101016202016303","0x050301017803016101016303016202","0x07030101780003016101016303016202","0x050301017803016202016101016303","0x07030101780003016202016101016303","0x050301017803016202016303016101","0x07030101780003016202016303016101","0x050301017803016303016101016202","0x07030101780003016303016101016202","0x050301017803016303016202016101","0x07030101780003016303016202016101"]} +{"tensor":{"type":"tensor(x{},y{})","cells":[]},"binary":["0x0503020178017900","0x070302017801790000"]} +{"tensor":{"type":"tensor(x{},y{})","cells":[{"address":{"x":"bar","y":"a"},"value":21},{"address":{"x":"foo","y":"a"},"value":11}]},"binary":["0x050302017801790203666F6F01610B03626172016115","0x07030201780179000203666F6F01610B03626172016115","0x05030201780179020362617201611503666F6F01610B","0x0703020178017900020362617201611503666F6F01610B"]} +{"tensor":{"type":"tensor(x{},y[10])","cells":[]},"binary":["0x07030101780101790A00"]} +{"tensor":{"type":"tensor(x[10],y{})","cells":[]},"binary":["0x07030101790101780A00"]} +{"tensor":{"type":"tensor(x{},y[3])","cells":[{"address":{"x":"a","y":0},"value":11},{"address":{"x":"a","y":1},"value":12},{"address":{"x":"a","y":2},"value":13},{"address":{"x":"b","y":0},"value":21},{"address":{"x":"b","y":1},"value":22},{"address":{"x":"b","y":2},"value":23}]},"binary":["0x0703010178010179030201610B0C0D0162151617","0x07030101780101790302016215161701610B0C0D"]} +{"tensor":{"type":"tensor(x[3],y{})","cells":[{"address":{"x":0,"y":"a"},"value":11},{"address":{"x":0,"y":"b"},"value":21},{"address":{"x":1,"y":"a"},"value":12},{"address":{"x":1,"y":"b"},"value":22},{"address":{"x":2,"y":"a"},"value":13},{"address":{"x":2,"y":"b"},"value":23}]},"binary":["0x0703010179010178030201610B0C0D0162151617","0x07030101790101780302016215161701610B0C0D"]} +{"num_tests":42} diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index e51569da988..9ad2c55f7e3 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1411,7 +1411,9 @@ ], "fields": [ "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE", - "public static final enum com.yahoo.tensor.TensorType$Value FLOAT" + "public static final enum com.yahoo.tensor.TensorType$Value FLOAT", + "public static final enum com.yahoo.tensor.TensorType$Value INT8", + "public static final enum com.yahoo.tensor.TensorType$Value BFLOAT16" ] }, "com.yahoo.tensor.TensorType": { @@ -3463,4 +3465,4 @@ ], "fields": [] } -} +} \ No newline at end of file diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index dc17c657db9..9f3d7c01c6b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -225,6 +225,10 @@ public abstract class IndexedTensor implements Tensor { b.append(tensor.get(index)); else if (tensor.type().valueType() == TensorType.Value.FLOAT) b.append(tensor.getFloat(index)); + else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) + b.append(tensor.getFloat(index)); + else if (tensor.type().valueType() == TensorType.Value.INT8) + b.append(tensor.getFloat(index)); else throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); @@ -295,6 +299,10 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + else if (type.valueType() == TensorType.Value.BFLOAT16) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + else if (type.valueType() == TensorType.Value.INT8) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); else @@ -315,6 +323,10 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + else if (type.valueType() == TensorType.Value.BFLOAT16) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.INT8) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); else @@ -335,6 +347,10 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.BFLOAT16) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.INT8) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); else diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index f608aead347..be5f4143f54 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -533,6 +533,10 @@ public class MixedTensor implements Tensor { b.append(getDouble(subspaceIndex, index, tensor)); else if (tensor.type().valueType() == TensorType.Value.FLOAT) b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats + else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) + b.append(getDouble(subspaceIndex, index, tensor)); + else if (tensor.type().valueType() == TensorType.Value.INT8) + b.append(getDouble(subspaceIndex, index, tensor)); else throw new IllegalStateException("Unexpected value type " + type.valueType()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index becec1a4493..126fee878ab 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -201,6 +201,10 @@ class TensorParser { return Double.parseDouble(cellValueString); else if (cellValueType == TensorType.Value.FLOAT) return Float.parseFloat(cellValueString); + else if (cellValueType == TensorType.Value.BFLOAT16) + return Float.parseFloat(cellValueString); + else if (cellValueType == TensorType.Value.INT8) + return Float.parseFloat(cellValueString); else throw new IllegalArgumentException(cellValueType + " is not supported"); } catch (NumberFormatException e) { @@ -291,6 +295,10 @@ class TensorParser { builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); else if (builder.type().valueType() == TensorType.Value.FLOAT) builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); + else if (builder.type().valueType() == TensorType.Value.BFLOAT16) + builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); + else if (builder.type().valueType() == TensorType.Value.INT8) + builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); } } @@ -355,6 +363,10 @@ class TensorParser { builder.cell(address, (Double)number); else if (builder.type().valueType() == TensorType.Value.FLOAT) builder.cell(address, (Float)number); + else if (builder.type().valueType() == TensorType.Value.BFLOAT16) + builder.cell(address, (Float)number); + else if (builder.type().valueType() == TensorType.Value.INT8) + builder.cell(address, (Float)number); } } @@ -392,6 +404,10 @@ class TensorParser { builder.cell(address, Double.parseDouble(cellValueString)); else if (cellValueType == TensorType.Value.FLOAT) builder.cell(address, Float.parseFloat(cellValueString)); + else if (cellValueType == TensorType.Value.BFLOAT16) + builder.cell(address, Float.parseFloat(cellValueString)); + else if (cellValueType == TensorType.Value.INT8) + builder.cell(address, Float.parseFloat(cellValueString)); else throw new IllegalArgumentException(cellValueType + " is not supported"); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 236e9d31c39..d7cf5bffcfe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -33,7 +33,7 @@ public class TensorType { public enum Value { // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below - DOUBLE("double"), FLOAT("float"); + DOUBLE("double"), FLOAT("float"), INT8("int8"), BFLOAT16("bfloat16"); private final String id; @@ -59,6 +59,9 @@ public class TensorType { public static Value largestOf(Value value1, Value value2) { if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE; + if (value1 == FLOAT || value2 == FLOAT) return FLOAT; + if (value1 == BFLOAT16 || value2 == BFLOAT16) return FLOAT; + if (value1 == INT8 || value2 == INT8) return FLOAT; return FLOAT; } @@ -69,8 +72,10 @@ public class TensorType { switch (valueTypeString) { case "double" : return Value.DOUBLE; case "float" : return Value.FLOAT; - default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + - " but was '" + valueTypeString + "'"); + case "bfloat16" : return Value.BFLOAT16; + case "int8" : return Value.INT8; + default : throw new IllegalArgumentException("Value type must be either 'double', 'float', " + + "'bfloat16', or 'int8' but was '" + valueTypeString + "'"); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 0cec09157fb..edb68025d45 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -53,6 +53,8 @@ public class DenseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: encodeDoubleCells(tensor, buffer); break; case FLOAT: encodeFloatCells(tensor, buffer); break; + case BFLOAT16: encodeBFloat16Cells(tensor, buffer); break; + case INT8: encodeInt8Cells(tensor, buffer); break; } } @@ -66,6 +68,16 @@ public class DenseBinaryFormat implements BinaryFormat { buffer.putFloat(tensor.getFloat(i)); } + private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.putShort((short)(Float.floatToRawIntBits(tensor.getFloat(i)) >>> 16)); + } + + private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { + for (int i = 0; i < tensor.size(); i++) + buffer.put((byte) tensor.getFloat(i)); + } + @Override public Tensor decode(Optional optionalType, GrowableByteBuffer buffer) { TensorType type; @@ -111,6 +123,8 @@ public class DenseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break; case FLOAT: decodeFloatCells(sizes, builder, buffer); break; + case BFLOAT16: decodeBFloat16Cells(sizes, builder, buffer); break; + case INT8: decodeInt8Cells(sizes, builder, buffer); break; } } @@ -124,4 +138,16 @@ public class DenseBinaryFormat implements BinaryFormat { builder.cellByDirectIndex(i, buffer.getFloat()); } + private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) { + builder.cellByDirectIndex(i, Float.intBitsToFloat(buffer.getShort() << 16)); + } + } + + private void decodeInt8Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { + for (long i = 0; i < sizes.totalSize(); i++) { + builder.cellByDirectIndex(i, (float) buffer.get()); + } + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index bc247e5561f..7d500614caa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -64,6 +64,9 @@ class MixedBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; + case BFLOAT16: encodeCells(buffer, tensor, (val) -> + buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break; + case INT8: encodeCells(buffer, tensor, (val) -> buffer.put(((byte)val.floatValue()))); break; } } @@ -127,6 +130,9 @@ class MixedBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; + case BFLOAT16: decodeCells(buffer, builder, type, () -> + (double)Float.intBitsToFloat(buffer.getShort() << 16)); break; + case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index cd671f824fa..160cf660e1b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -52,6 +52,9 @@ class SparseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; + case BFLOAT16: encodeCells(buffer, tensor, (val) -> + buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break; + case INT8: encodeCells(buffer, tensor, (val) -> buffer.put((byte)(val.floatValue()))); break; } } @@ -102,6 +105,9 @@ class SparseBinaryFormat implements BinaryFormat { switch (serializationValueType) { case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; + case BFLOAT16: decodeCells(buffer, builder, type, () -> + (double)Float.intBitsToFloat(buffer.getShort() << 16)); break; + case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 5c47572c779..cddb283489c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -29,6 +29,8 @@ public class TypedBinaryFormat { private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing private static final int FLOAT_VALUE_TYPE = 1; + private static final int BFLOAT16_VALUE_TYPE = 2; + private static final int INT8_VALUE_TYPE = 3; public static byte[] encode(Tensor tensor) { GrowableByteBuffer buffer = new GrowableByteBuffer(); @@ -113,6 +115,8 @@ public class TypedBinaryFormat { switch (valueType) { case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break; case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break; + case BFLOAT16: buffer.putInt1_4Bytes(BFLOAT16_VALUE_TYPE); break; + case INT8: buffer.putInt1_4Bytes(INT8_VALUE_TYPE); break; default: throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType); } @@ -123,8 +127,11 @@ public class TypedBinaryFormat { switch (valueType) { case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE; case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT; + case BFLOAT16_VALUE_TYPE: return TensorType.Value.BFLOAT16; + case INT8_VALUE_TYPE: return TensorType.Value.INT8; } - throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal."); + throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. " + + "Only 0(double), 1(float), 2(bfloat16), or 3(int8) is legal."); } private static byte[] asByteArray(GrowableByteBuffer buffer) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 5bd1bbdba37..b47c0873535 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -50,6 +50,35 @@ public class TensorTestCase { assertEquals(Tensor.from("tensor(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class); assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(5.0, 0).build().getClass(), IndexedFloatTensor.class); + + assertEquals(Tensor.from("tensor(x[1]):[5]").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + + assertEquals(Tensor.from("tensor(x[1]):[5]").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + } + + private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) { + Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }"); + Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }"); + assertEquals(valueType, t1.multiply(t2).type().valueType()); + assertEquals(valueType, t2.multiply(t1).type().valueType()); + } + + @Test + public void testValueTypeResolving() { + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16"); + assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "float"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16"); + assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16"); + assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8"); + assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8"); } @Test diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index a547f941d8e..caa125dfef7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -96,8 +96,12 @@ public class TensorTypeTestCase { assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); assertValueType(TensorType.Value.FLOAT, "tensor(x[])"); + assertValueType(TensorType.Value.BFLOAT16, "tensor(x[])"); + assertValueType(TensorType.Value.INT8, "tensor(x[])"); assertEquals("tensor(x[])", TensorType.fromSpec("tensor(x[])").toString()); assertEquals("tensor(x[])", TensorType.fromSpec("tensor(x[])").toString()); + assertEquals("tensor(x[])", TensorType.fromSpec("tensor(x[])").toString()); + assertEquals("tensor(x[])", TensorType.fromSpec("tensor(x[])").toString()); } private static void assertTensorType(String typeSpec) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 5d1bc7b0c3f..3c79b0c769c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -41,7 +41,7 @@ public class DenseBinaryFormatTestCase { } @Test - public void requireThatDefaultSerializationFormatDoNotChange() { + public void requireThatDefaultSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[]{2, // binary format type 2, // dimension count 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size @@ -54,7 +54,7 @@ public class DenseBinaryFormatTestCase { } @Test - public void requireThatFloatSerializationFormatDoNotChange() { + public void requireThatFloatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[]{6, // binary format type 1, // float type 2, // dimension count @@ -67,10 +67,45 @@ public class DenseBinaryFormatTestCase { assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } + @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[]{6, // binary format type + 2, // bfloat16 type + 2, // dimension count + 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size + 1, (byte) 'z', 1, // dimension z with size + 64, 0, // value 1 + 64, 64, // value 2 + }; + Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[]{6, // binary format type + 3, // int8 type + 2, // dimension count + 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size + 1, (byte) 'z', 1, // dimension z with size + 2, // value 1 + 3, // value 2 + }; + Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + @Test public void testSerializationOfDifferentValueTypes() { + assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); + assertSerialization("tensor(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); + assertSerialization("tensor(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); + assertSerialization("tensor(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); + assertSerialization("tensor(x[2],y[2]):[2, 3, 4, 5]"); } private void assertSerialization(String tensorString) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 81de8a9db4c..3ca20661587 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -134,4 +134,19 @@ public class JsonFormatTestCase { } } + private void assertEncodeDecode(Tensor tensor) { + Tensor decoded = JsonFormat.decode(tensor.type(), JsonFormat.encodeWithType(tensor)); + assertEquals(tensor, decoded); + assertEquals(tensor.type(), decoded.type()); + } + + @Test + public void testTensorCellTypes() { + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]")); + assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2,3,5,8]")); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java index 69ef4922d8d..e9f8c81f21b 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; +import java.util.Arrays; import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -77,10 +78,71 @@ public class MixedBinaryFormatTestCase { assertSerialization(tensor); } + @Test + public void requireThatDefaultSerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {3, // binary format type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0 + 2, (byte)'c', (byte)'d', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatFloatSerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {7, // binary format type + 1, // float type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 64, 0, 0, 0, // cell 0 + 2, (byte)'c', (byte)'d', 64, 64, 0, 0}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {7, // binary format type + 2, // bfloat16 type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 64, 0, // cell 0 + 2, (byte)'c', (byte)'d', 64, 64}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] {7, // binary format type + 3, // int8 type + 1, // number of sparse dimensions + 2, (byte)'x', (byte)'y', // name of sparse dimension + 1, // number of dense dimensions + 1, (byte)'z', 1, // name and size of dense dimension + 2, // num cells, + 2, (byte)'a', (byte)'b', 2, // cell 0 + 2, (byte)'c', (byte)'d', 3}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + @Test public void testSerializationOfDifferentValueTypes() { assertSerialization("tensor(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x{},y[2]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); } private void assertSerialization(String tensorString) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 50b71024ddf..2a622b73513 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -55,19 +55,19 @@ public class SparseBinaryFormatTestCase { } @Test - public void requireThatSerializationFormatDoNotChange() { + public void requireThatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[] {1, // binary format type 2, // num dimensions 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions 2, // num cells, 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1 - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test - public void requireThatFloatSerializationFormatDoNotChange() { + public void requireThatFloatSerializationFormatDoesNotChange() { byte[] encodedTensor = new byte[] { 5, // binary format type 1, // float type @@ -76,14 +76,44 @@ public class SparseBinaryFormatTestCase { 2, // num cells, 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1 - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatBFloat16SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] { + 5, // binary format type + 2, // bfloat16 type + 2, // num dimensions + 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions + 2, // num cells, + 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, // cell 0 + 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); + } + + @Test + public void requireThatInt8SerializationFormatDoesNotChange() { + byte[] encodedTensor = new byte[] { + 5, // binary format type + 3, // int8 type + 2, // num dimensions + 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions + 2, // num cells, + 2, (byte)'a', (byte)'b', 1, (byte)'e', 2, // cell 0 + 2, (byte)'c', (byte)'d', 1, (byte)'e', 3}; // cell 1 + Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test public void testSerializationOfDifferentValueTypes() { assertSerialization("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor(x{},y{}):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); } private void assertSerialization(String tensorString) { -- cgit v1.2.3 From 8ae580ad9d7a2b7bd2420c5f543a21ac0d7a7c9a Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 8 Apr 2021 11:32:30 +0200 Subject: Add bfloat16 and int8 to Java cell_cast function --- .../src/main/java/com/yahoo/tensor/functions/CellCast.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index d052e383c85..d853c4a9069 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -64,10 +64,14 @@ public class CellCast extends PrimitiveTensorFunction i = tensor.cellIterator(); i.hasNext(); ) { Tensor.Cell cell = i.next(); - if (fromValueType == TensorType.Value.FLOAT) { - builder.cell(cell.getKey(), cell.getFloatValue()); - } else if (fromValueType == TensorType.Value.DOUBLE) { + if (fromValueType == TensorType.Value.DOUBLE) { builder.cell(cell.getKey(), cell.getDoubleValue()); + } else if (fromValueType == TensorType.Value.FLOAT) { + builder.cell(cell.getKey(), cell.getFloatValue()); + } else if (fromValueType == TensorType.Value.BFLOAT16) { + builder.cell(cell.getKey(), cell.getFloatValue()); + } else if (fromValueType == TensorType.Value.INT8) { + builder.cell(cell.getKey(), cell.getFloatValue()); } else { builder.cell(cell.getKey(), cell.getValue()); } -- cgit v1.2.3 From cb1ea8c336adb05c90200468b25fe4ab89ee803c Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 8 Apr 2021 15:05:01 +0200 Subject: Resolve feedback from PR review --- .../main/java/com/yahoo/tensor/IndexedTensor.java | 74 ++++++++++------------ .../main/java/com/yahoo/tensor/MixedTensor.java | 18 +++--- .../main/java/com/yahoo/tensor/TensorParser.java | 66 ++++++++----------- .../src/main/java/com/yahoo/tensor/TensorType.java | 18 +++--- .../java/com/yahoo/tensor/functions/CellCast.java | 17 ++--- .../tensor/serialization/DenseBinaryFormat.java | 7 +- .../tensor/serialization/MixedBinaryFormat.java | 4 +- .../tensor/serialization/SparseBinaryFormat.java | 4 +- .../tensor/serialization/TypedBinaryFormat.java | 8 +++ .../test/java/com/yahoo/tensor/TensorTestCase.java | 8 +-- 10 files changed, 98 insertions(+), 126 deletions(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 9f3d7c01c6b..c369fe96562 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -221,16 +221,14 @@ public abstract class IndexedTensor implements Tensor { b.append("["); // value - if (tensor.type().valueType() == TensorType.Value.DOUBLE) - b.append(tensor.get(index)); - else if (tensor.type().valueType() == TensorType.Value.FLOAT) - b.append(tensor.getFloat(index)); - else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) - b.append(tensor.getFloat(index)); - else if (tensor.type().valueType() == TensorType.Value.INT8) - b.append(tensor.getFloat(index)); - else - throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); + switch (tensor.type().valueType()) { + case DOUBLE: b.append(tensor.get(index)); break; + case FLOAT: b.append(tensor.getFloat(index)); break; + case BFLOAT16: b.append(tensor.getFloat(index)); break; + case INT8: b.append(tensor.getFloat(index)); break; + default: + throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); + } // end bracket and comma for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) @@ -296,17 +294,14 @@ public abstract class IndexedTensor implements Tensor { */ public static Builder of(TensorType type, DimensionSizes sizes) { validate(type, sizes); - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.BFLOAT16) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.INT8) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + } } /** @@ -320,17 +315,14 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, float[] values) { validate(type, sizes); validateSizes(sizes, values.length); - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - else if (type.valueType() == TensorType.Value.BFLOAT16) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.INT8) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + } } /** @@ -344,17 +336,15 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, double[] values) { validate(type, sizes); validateSizes(sizes, values.length); + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.BFLOAT16) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.INT8) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default + } } private static void validateSizes(DimensionSizes sizes, int length) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index be5f4143f54..606509bbfd8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -529,16 +529,14 @@ public class MixedTensor implements Tensor { b.append("["); // value - if (type.valueType() == TensorType.Value.DOUBLE) - b.append(getDouble(subspaceIndex, index, tensor)); - else if (tensor.type().valueType() == TensorType.Value.FLOAT) - b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats - else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) - b.append(getDouble(subspaceIndex, index, tensor)); - else if (tensor.type().valueType() == TensorType.Value.INT8) - b.append(getDouble(subspaceIndex, index, tensor)); - else - throw new IllegalStateException("Unexpected value type " + type.valueType()); + switch (type.valueType()) { + case DOUBLE: b.append(getDouble(subspaceIndex, index, tensor)); break; + case FLOAT: b.append(getDouble(subspaceIndex, index, tensor)); break; // TODO: Really use floats + case BFLOAT16: b.append(getDouble(subspaceIndex, index, tensor)); break; + case INT8: b.append(getDouble(subspaceIndex, index, tensor)); break; + default: + throw new IllegalStateException("Unexpected value type " + type.valueType()); + } // end bracket and comma for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 126fee878ab..0a1d9b6cf6e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -197,16 +197,14 @@ class TensorParser { try { String cellValueString = string.substring(position, nextNumberEnd); try { - if (cellValueType == TensorType.Value.DOUBLE) - return Double.parseDouble(cellValueString); - else if (cellValueType == TensorType.Value.FLOAT) - return Float.parseFloat(cellValueString); - else if (cellValueType == TensorType.Value.BFLOAT16) - return Float.parseFloat(cellValueString); - else if (cellValueType == TensorType.Value.INT8) - return Float.parseFloat(cellValueString); - else - throw new IllegalArgumentException(cellValueType + " is not supported"); + switch (cellValueType) { + case DOUBLE: return Double.parseDouble(cellValueString); + case FLOAT: return Float.parseFloat(cellValueString); + case BFLOAT16: return Float.parseFloat(cellValueString); + case INT8: return Float.parseFloat(cellValueString); + default: + throw new IllegalArgumentException(cellValueType + " is not supported"); + } } catch (NumberFormatException e) { throw new IllegalArgumentException("At value position " + position + ": '" + cellValueString + "' is not a valid " + cellValueType); @@ -291,16 +289,13 @@ class TensorParser { protected void consumeNumber() { Number number = consumeNumber(builder.type().valueType()); - if (builder.type().valueType() == TensorType.Value.DOUBLE) - builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); - else if (builder.type().valueType() == TensorType.Value.FLOAT) - builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); - else if (builder.type().valueType() == TensorType.Value.BFLOAT16) - builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); - else if (builder.type().valueType() == TensorType.Value.INT8) - builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); + switch (builder.type().valueType()) { + case DOUBLE: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); break; + case FLOAT: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break; + case BFLOAT16: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break; + case INT8: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break; + } } - } /** @@ -359,16 +354,13 @@ class TensorParser { private void consumeNumber(TensorAddress address) { Number number = consumeNumber(builder.type().valueType()); - if (builder.type().valueType() == TensorType.Value.DOUBLE) - builder.cell(address, (Double)number); - else if (builder.type().valueType() == TensorType.Value.FLOAT) - builder.cell(address, (Float)number); - else if (builder.type().valueType() == TensorType.Value.BFLOAT16) - builder.cell(address, (Float)number); - else if (builder.type().valueType() == TensorType.Value.INT8) - builder.cell(address, (Float)number); + switch (builder.type().valueType()) { + case DOUBLE: builder.cell(address, (Double)number); break; + case FLOAT: builder.cell(address, (Float)number); break; + case BFLOAT16: builder.cell(address, (Float)number); break; + case INT8: builder.cell(address, (Float)number); break; + } } - } private static class MappedValueParser extends ValueParser { @@ -400,16 +392,14 @@ class TensorParser { TensorType.Value cellValueType = builder.type().valueType(); String cellValueString = string.substring(position, valueEnd).trim(); try { - if (cellValueType == TensorType.Value.DOUBLE) - builder.cell(address, Double.parseDouble(cellValueString)); - else if (cellValueType == TensorType.Value.FLOAT) - builder.cell(address, Float.parseFloat(cellValueString)); - else if (cellValueType == TensorType.Value.BFLOAT16) - builder.cell(address, Float.parseFloat(cellValueString)); - else if (cellValueType == TensorType.Value.INT8) - builder.cell(address, Float.parseFloat(cellValueString)); - else - throw new IllegalArgumentException(cellValueType + " is not supported"); + switch (cellValueType) { + case DOUBLE: builder.cell(address, Double.parseDouble(cellValueString)); break; + case FLOAT: builder.cell(address, Float.parseFloat(cellValueString)); break; + case BFLOAT16: builder.cell(address, Float.parseFloat(cellValueString)); break; + case INT8: builder.cell(address, Float.parseFloat(cellValueString)); break; + default: + throw new IllegalArgumentException(cellValueType + " is not supported"); + } } catch (NumberFormatException e) { throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index d7cf5bffcfe..0f67c25337b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -60,23 +60,21 @@ public class TensorType { public static Value largestOf(Value value1, Value value2) { if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE; if (value1 == FLOAT || value2 == FLOAT) return FLOAT; - if (value1 == BFLOAT16 || value2 == BFLOAT16) return FLOAT; - if (value1 == INT8 || value2 == INT8) return FLOAT; - return FLOAT; + if (value1 == BFLOAT16 || value2 == BFLOAT16) return BFLOAT16; + return INT8; } @Override public String toString() { return name().toLowerCase(); } public static Value fromId(String valueTypeString) { - switch (valueTypeString) { - case "double" : return Value.DOUBLE; - case "float" : return Value.FLOAT; - case "bfloat16" : return Value.BFLOAT16; - case "int8" : return Value.INT8; - default : throw new IllegalArgumentException("Value type must be either 'double', 'float', " + - "'bfloat16', or 'int8' but was '" + valueTypeString + "'"); + for(Value value : Value.values()) { + if (value.id.equals(valueTypeString)) { + return value; + } } + throw new IllegalArgumentException("Value type must be either 'double', 'float', " + + "'bfloat16', or 'int8' but was '" + valueTypeString + "'"); } }; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index d853c4a9069..b6ea0d04a50 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -64,16 +64,13 @@ public class CellCast extends PrimitiveTensorFunction i = tensor.cellIterator(); i.hasNext(); ) { Tensor.Cell cell = i.next(); - if (fromValueType == TensorType.Value.DOUBLE) { - builder.cell(cell.getKey(), cell.getDoubleValue()); - } else if (fromValueType == TensorType.Value.FLOAT) { - builder.cell(cell.getKey(), cell.getFloatValue()); - } else if (fromValueType == TensorType.Value.BFLOAT16) { - builder.cell(cell.getKey(), cell.getFloatValue()); - } else if (fromValueType == TensorType.Value.INT8) { - builder.cell(cell.getKey(), cell.getFloatValue()); - } else { - builder.cell(cell.getKey(), cell.getValue()); + switch (fromValueType) { + case DOUBLE: builder.cell(cell.getKey(), cell.getDoubleValue()); break; + case FLOAT: builder.cell(cell.getKey(), cell.getFloatValue()); break; + case BFLOAT16: builder.cell(cell.getKey(), cell.getFloatValue()); break; + case INT8: builder.cell(cell.getKey(), cell.getFloatValue()); break; + default: + builder.cell(cell.getKey(), cell.getValue()); } } return builder.build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index edb68025d45..1567c95c9fa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -7,10 +7,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.Iterator; import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Supplier; /** * Implementation of a dense binary format for a tensor on the form: @@ -70,7 +67,7 @@ public class DenseBinaryFormat implements BinaryFormat { private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { for (int i = 0; i < tensor.size(); i++) - buffer.putShort((short)(Float.floatToRawIntBits(tensor.getFloat(i)) >>> 16)); + buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i))); } private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { @@ -140,7 +137,7 @@ public class DenseBinaryFormat implements BinaryFormat { private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { for (long i = 0; i < sizes.totalSize(); i++) { - builder.cellByDirectIndex(i, Float.intBitsToFloat(buffer.getShort() << 16)); + builder.cellByDirectIndex(i, TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index 7d500614caa..6cb9a63fe68 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -65,7 +65,7 @@ class MixedBinaryFormat implements BinaryFormat { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; case BFLOAT16: encodeCells(buffer, tensor, (val) -> - buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break; + buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(val.floatValue()))); break; case INT8: encodeCells(buffer, tensor, (val) -> buffer.put(((byte)val.floatValue()))); break; } } @@ -131,7 +131,7 @@ class MixedBinaryFormat implements BinaryFormat { case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; case BFLOAT16: decodeCells(buffer, builder, type, () -> - (double)Float.intBitsToFloat(buffer.getShort() << 16)); break; + (double)TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); break; case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 160cf660e1b..763b722a90c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -53,7 +53,7 @@ class SparseBinaryFormat implements BinaryFormat { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; case BFLOAT16: encodeCells(buffer, tensor, (val) -> - buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break; + buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(val.floatValue()))); break; case INT8: encodeCells(buffer, tensor, (val) -> buffer.put((byte)(val.floatValue()))); break; } } @@ -106,7 +106,7 @@ class SparseBinaryFormat implements BinaryFormat { case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; case BFLOAT16: decodeCells(buffer, builder, type, () -> - (double)Float.intBitsToFloat(buffer.getShort() << 16)); break; + (double)TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); break; case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index cddb283489c..be04be80ed9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -141,4 +141,12 @@ public class TypedBinaryFormat { return result; } + static short bFloat16BitsFromFloat(float val) { + return (short) (Float.floatToRawIntBits(val) >>> 16); + } + + static float floatFromBFloat16Bits(short bits) { + return Float.intBitsToFloat(bits << 16); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index b47c0873535..572dc433d71 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -71,14 +71,8 @@ public class TensorTestCase { public void testValueTypeResolving() { assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double"); assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float"); - assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16"); - assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8"); assertCellTypeResult(TensorType.Value.FLOAT, "float", "float"); - assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16"); - assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8"); - assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16"); - assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8"); - assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8"); + // Test bfloat16 and int8 when we have proper cell type resolving in place. } @Test -- cgit v1.2.3 From 4a33700665782a9ac22522dc5a8f8138f07b5b73 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 8 Apr 2021 15:05:57 +0200 Subject: Fix C++ serialization test for new cell types --- .../make_tensor_binary_format_test_spec.cpp | 31 +++++++++++++++------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp b/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp index 6e882fc3d9d..974f95a2add 100644 --- a/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp +++ b/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include #include #include #include @@ -20,14 +22,20 @@ using Dict = std::vector; template std::vector with_cell_type_opts(); template <> std::vector with_cell_type_opts() { return {false, true}; } template <> std::vector with_cell_type_opts() { return {true}; } +template <> std::vector with_cell_type_opts() { return {true}; } +template <> std::vector with_cell_type_opts() { return {true}; } template uint8_t cell_type_id(); template <> uint8_t cell_type_id() { return 0; } template <> uint8_t cell_type_id() { return 1; } +template <> uint8_t cell_type_id() { return 2; } +template <> uint8_t cell_type_id() { return 3; } template const char *cell_type_str(); template <> const char *cell_type_str() { return ""; } template <> const char *cell_type_str() { return ""; } +template <> const char *cell_type_str() { return ""; } +template <> const char *cell_type_str() { return ""; } template nbostream make_sparse(bool with_cell_type) { nbostream data; @@ -62,7 +70,8 @@ template nbostream make_mixed(bool with_cell_type) { return data; } -void set_tensor(Cursor &test, const TensorSpec &spec) { +void set_tensor(Cursor &test, const TensorSpec &spec_in) { + auto spec = spec_in.normalize(); const Inspector &old_tensor = test["tensor"]; if (old_tensor.valid()) { TensorSpec old_spec = TensorSpec::from_slime(old_tensor); @@ -183,8 +192,8 @@ void make_vector_test(Cursor &test, size_t x_size) { for (size_t x = 0; x < x_size; ++x) { double value = val(x); spec.add({{"x", x}}, value); - dense << static_cast(value); - mixed << static_cast(value); + dense << T(value); + mixed << T(value); } set_tensor(test, spec); add_binary(test, {dense, mixed}); @@ -212,8 +221,8 @@ void make_matrix_test(Cursor &test, size_t x_size, size_t y_size) { for (size_t y = 0; y < y_size; ++y) { double value = mix({val(x), val(y)}); spec.add({{"x", x}, {"y", y}}, value); - dense << static_cast(value); - mixed << static_cast(value); + dense << T(value); + mixed << T(value); } } set_tensor(test, spec); @@ -245,8 +254,8 @@ void make_map_test(Cursor &test, const Dict &x_dict_in) { spec.add({{"x", x}}, value); sparse.writeSmallString(x); mixed.writeSmallString(x); - sparse << static_cast(value); - mixed << static_cast(value); + sparse << T(value); + mixed << T(value); } set_tensor(test, spec); add_binary(test, {sparse, mixed}); @@ -285,8 +294,8 @@ void make_mesh_test(Cursor &test, const Dict &x_dict_in, const vespalib::string sparse.writeSmallString(y); mixed.writeSmallString(x); mixed.writeSmallString(y); - sparse << static_cast(value); - mixed << static_cast(value); + sparse << T(value); + mixed << T(value); } set_tensor(test, spec); add_binary(test, {sparse, mixed}); @@ -326,7 +335,7 @@ void make_vector_map_test(Cursor &test, for (size_t idx = 0; idx < indexed_size; ++idx) { double value = mix({val(label), val(idx)}); spec.add({{mapped_name, label}, {indexed_name, idx}}, value); - mixed << static_cast(value); + mixed << T(value); } } set_tensor(test, spec); @@ -360,6 +369,8 @@ void make_tests(test::TestWriter &writer) { make_number_test(writer.create(), 42.0); make_typed_tests(writer); make_typed_tests(writer); + make_typed_tests(writer); + make_typed_tests(writer); } int main(int, char **) { -- cgit v1.2.3