diff options
author | Lester Solbakken <lesters@oath.com> | 2021-04-08 15:05:01 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-04-08 15:05:01 +0200 |
commit | cb1ea8c336adb05c90200468b25fe4ab89ee803c (patch) | |
tree | 1d0b5da3bf31ec8be03c38818ad9a08d592de120 | |
parent | 8ae580ad9d7a2b7bd2420c5f543a21ac0d7a7c9a (diff) |
Resolve feedback from PR review
10 files changed, 98 insertions, 126 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 9f3d7c01c6b..c369fe96562 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -221,16 +221,14 @@ public abstract class IndexedTensor implements Tensor { b.append("["); // value - if (tensor.type().valueType() == TensorType.Value.DOUBLE) - b.append(tensor.get(index)); - else if (tensor.type().valueType() == TensorType.Value.FLOAT) - b.append(tensor.getFloat(index)); - else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) - b.append(tensor.getFloat(index)); - else if (tensor.type().valueType() == TensorType.Value.INT8) - b.append(tensor.getFloat(index)); - else - throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); + switch (tensor.type().valueType()) { + case DOUBLE: b.append(tensor.get(index)); break; + case FLOAT: b.append(tensor.getFloat(index)); break; + case BFLOAT16: b.append(tensor.getFloat(index)); break; + case INT8: b.append(tensor.getFloat(index)); break; + default: + throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); + } // end bracket and comma for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) @@ -296,17 +294,14 @@ public abstract class IndexedTensor implements Tensor { */ public static Builder of(TensorType type, DimensionSizes sizes) { validate(type, sizes); - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.BFLOAT16) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.INT8) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + } } /** @@ -320,17 +315,14 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, float[] values) { validate(type, sizes); validateSizes(sizes, values.length); - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - else if (type.valueType() == TensorType.Value.BFLOAT16) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.INT8) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + } } /** @@ -344,17 +336,15 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, double[] values) { validate(type, sizes); validateSizes(sizes, values.length); + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.BFLOAT16) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.INT8) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default + } } private static void validateSizes(DimensionSizes sizes, int length) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index be5f4143f54..606509bbfd8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -529,16 +529,14 @@ public class MixedTensor implements Tensor { b.append("["); // value - if (type.valueType() == TensorType.Value.DOUBLE) - b.append(getDouble(subspaceIndex, index, tensor)); - else if (tensor.type().valueType() == TensorType.Value.FLOAT) - b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats - else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) - b.append(getDouble(subspaceIndex, index, tensor)); - else if (tensor.type().valueType() == TensorType.Value.INT8) - b.append(getDouble(subspaceIndex, index, tensor)); - else - throw new IllegalStateException("Unexpected value type " + type.valueType()); + switch (type.valueType()) { + case DOUBLE: b.append(getDouble(subspaceIndex, index, tensor)); break; + case FLOAT: b.append(getDouble(subspaceIndex, index, tensor)); break; // TODO: Really use floats + case BFLOAT16: b.append(getDouble(subspaceIndex, index, tensor)); break; + case INT8: b.append(getDouble(subspaceIndex, index, tensor)); break; + default: + throw new IllegalStateException("Unexpected value type " + type.valueType()); + } // end bracket and comma for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 126fee878ab..0a1d9b6cf6e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -197,16 +197,14 @@ class TensorParser { try { String cellValueString = string.substring(position, nextNumberEnd); try { - if (cellValueType == TensorType.Value.DOUBLE) - return Double.parseDouble(cellValueString); - else if (cellValueType == TensorType.Value.FLOAT) - return Float.parseFloat(cellValueString); - else if (cellValueType == TensorType.Value.BFLOAT16) - return Float.parseFloat(cellValueString); - else if (cellValueType == TensorType.Value.INT8) - return Float.parseFloat(cellValueString); - else - throw new IllegalArgumentException(cellValueType + " is not supported"); + switch (cellValueType) { + case DOUBLE: return Double.parseDouble(cellValueString); + case FLOAT: return Float.parseFloat(cellValueString); + case BFLOAT16: return Float.parseFloat(cellValueString); + case INT8: return Float.parseFloat(cellValueString); + default: + throw new IllegalArgumentException(cellValueType + " is not supported"); + } } catch (NumberFormatException e) { throw new IllegalArgumentException("At value position " + position + ": '" + cellValueString + "' is not a valid " + cellValueType); @@ -291,16 +289,13 @@ class TensorParser { protected void consumeNumber() { Number number = consumeNumber(builder.type().valueType()); - if (builder.type().valueType() == TensorType.Value.DOUBLE) - builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); - else if (builder.type().valueType() == TensorType.Value.FLOAT) - builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); - else if (builder.type().valueType() == TensorType.Value.BFLOAT16) - builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); - else if (builder.type().valueType() == TensorType.Value.INT8) - builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); + switch (builder.type().valueType()) { + case DOUBLE: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); break; + case FLOAT: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break; + case BFLOAT16: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break; + case INT8: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break; + } } - } /** @@ -359,16 +354,13 @@ class TensorParser { private void consumeNumber(TensorAddress address) { Number number = consumeNumber(builder.type().valueType()); - if (builder.type().valueType() == TensorType.Value.DOUBLE) - builder.cell(address, (Double)number); - else if (builder.type().valueType() == TensorType.Value.FLOAT) - builder.cell(address, (Float)number); - else if (builder.type().valueType() == TensorType.Value.BFLOAT16) - builder.cell(address, (Float)number); - else if (builder.type().valueType() == TensorType.Value.INT8) - builder.cell(address, (Float)number); + switch (builder.type().valueType()) { + case DOUBLE: builder.cell(address, (Double)number); break; + case FLOAT: builder.cell(address, (Float)number); break; + case BFLOAT16: builder.cell(address, (Float)number); break; + case INT8: builder.cell(address, (Float)number); break; + } } - } private static class MappedValueParser extends ValueParser { @@ -400,16 +392,14 @@ class TensorParser { TensorType.Value cellValueType = builder.type().valueType(); String cellValueString = string.substring(position, valueEnd).trim(); try { - if (cellValueType == TensorType.Value.DOUBLE) - builder.cell(address, Double.parseDouble(cellValueString)); - else if (cellValueType == TensorType.Value.FLOAT) - builder.cell(address, Float.parseFloat(cellValueString)); - else if (cellValueType == TensorType.Value.BFLOAT16) - builder.cell(address, Float.parseFloat(cellValueString)); - else if (cellValueType == TensorType.Value.INT8) - builder.cell(address, Float.parseFloat(cellValueString)); - else - throw new IllegalArgumentException(cellValueType + " is not supported"); + switch (cellValueType) { + case DOUBLE: builder.cell(address, Double.parseDouble(cellValueString)); break; + case FLOAT: builder.cell(address, Float.parseFloat(cellValueString)); break; + case BFLOAT16: builder.cell(address, Float.parseFloat(cellValueString)); break; + case INT8: builder.cell(address, Float.parseFloat(cellValueString)); break; + default: + throw new IllegalArgumentException(cellValueType + " is not supported"); + } } catch (NumberFormatException e) { throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index d7cf5bffcfe..0f67c25337b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -60,23 +60,21 @@ public class TensorType { public static Value largestOf(Value value1, Value value2) { if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE; if (value1 == FLOAT || value2 == FLOAT) return FLOAT; - if (value1 == BFLOAT16 || value2 == BFLOAT16) return FLOAT; - if (value1 == INT8 || value2 == INT8) return FLOAT; - return FLOAT; + if (value1 == BFLOAT16 || value2 == BFLOAT16) return BFLOAT16; + return INT8; } @Override public String toString() { return name().toLowerCase(); } public static Value fromId(String valueTypeString) { - switch (valueTypeString) { - case "double" : return Value.DOUBLE; - case "float" : return Value.FLOAT; - case "bfloat16" : return Value.BFLOAT16; - case "int8" : return Value.INT8; - default : throw new IllegalArgumentException("Value type must be either 'double', 'float', " + - "'bfloat16', or 'int8' but was '" + valueTypeString + "'"); + for(Value value : Value.values()) { + if (value.id.equals(valueTypeString)) { + return value; + } } + throw new IllegalArgumentException("Value type must be either 'double', 'float', " + + "'bfloat16', or 'int8' but was '" + valueTypeString + "'"); } }; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index d853c4a9069..b6ea0d04a50 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -64,16 +64,13 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM TensorType.Value fromValueType = tensor.type().valueType(); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Tensor.Cell cell = i.next(); - if (fromValueType == TensorType.Value.DOUBLE) { - builder.cell(cell.getKey(), cell.getDoubleValue()); - } else if (fromValueType == TensorType.Value.FLOAT) { - builder.cell(cell.getKey(), cell.getFloatValue()); - } else if (fromValueType == TensorType.Value.BFLOAT16) { - builder.cell(cell.getKey(), cell.getFloatValue()); - } else if (fromValueType == TensorType.Value.INT8) { - builder.cell(cell.getKey(), cell.getFloatValue()); - } else { - builder.cell(cell.getKey(), cell.getValue()); + switch (fromValueType) { + case DOUBLE: builder.cell(cell.getKey(), cell.getDoubleValue()); break; + case FLOAT: builder.cell(cell.getKey(), cell.getFloatValue()); break; + case BFLOAT16: builder.cell(cell.getKey(), cell.getFloatValue()); break; + case INT8: builder.cell(cell.getKey(), cell.getFloatValue()); break; + default: + builder.cell(cell.getKey(), cell.getValue()); } } return builder.build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index edb68025d45..1567c95c9fa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -7,10 +7,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.Iterator; import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Supplier; /** * Implementation of a dense binary format for a tensor on the form: @@ -70,7 +67,7 @@ public class DenseBinaryFormat implements BinaryFormat { private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { for (int i = 0; i < tensor.size(); i++) - buffer.putShort((short)(Float.floatToRawIntBits(tensor.getFloat(i)) >>> 16)); + buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i))); } private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { @@ -140,7 +137,7 @@ public class DenseBinaryFormat implements BinaryFormat { private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) { for (long i = 0; i < sizes.totalSize(); i++) { - builder.cellByDirectIndex(i, Float.intBitsToFloat(buffer.getShort() << 16)); + builder.cellByDirectIndex(i, TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index 7d500614caa..6cb9a63fe68 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -65,7 +65,7 @@ class MixedBinaryFormat implements BinaryFormat { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; case BFLOAT16: encodeCells(buffer, tensor, (val) -> - buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break; + buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(val.floatValue()))); break; case INT8: encodeCells(buffer, tensor, (val) -> buffer.put(((byte)val.floatValue()))); break; } } @@ -131,7 +131,7 @@ class MixedBinaryFormat implements BinaryFormat { case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; case BFLOAT16: decodeCells(buffer, builder, type, () -> - (double)Float.intBitsToFloat(buffer.getShort() << 16)); break; + (double)TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); break; case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 160cf660e1b..763b722a90c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -53,7 +53,7 @@ class SparseBinaryFormat implements BinaryFormat { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; case BFLOAT16: encodeCells(buffer, tensor, (val) -> - buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break; + buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(val.floatValue()))); break; case INT8: encodeCells(buffer, tensor, (val) -> buffer.put((byte)(val.floatValue()))); break; } } @@ -106,7 +106,7 @@ class SparseBinaryFormat implements BinaryFormat { case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; case BFLOAT16: decodeCells(buffer, builder, type, () -> - (double)Float.intBitsToFloat(buffer.getShort() << 16)); break; + (double)TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); break; case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index cddb283489c..be04be80ed9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -141,4 +141,12 @@ public class TypedBinaryFormat { return result; } + static short bFloat16BitsFromFloat(float val) { + return (short) (Float.floatToRawIntBits(val) >>> 16); + } + + static float floatFromBFloat16Bits(short bits) { + return Float.intBitsToFloat(bits << 16); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index b47c0873535..572dc433d71 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -71,14 +71,8 @@ public class TensorTestCase { public void testValueTypeResolving() { assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double"); assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float"); - assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16"); - assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8"); assertCellTypeResult(TensorType.Value.FLOAT, "float", "float"); - assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16"); - assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8"); - assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16"); - assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8"); - assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8"); + // Test bfloat16 and int8 when we have proper cell type resolving in place. } @Test |