diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-04-09 06:45:16 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-04-09 10:31:36 +0000 |
commit | 527eada361b1000fc28cca04a7234845c2df839c (patch) | |
tree | c125891e20484af3e6ed4a4585760b9b481f104a /vespajlib | |
parent | cc6d1c271efffd4a3321478884fb82c7e6141091 (diff) |
restrict values to fit into target cell type
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java | 46 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java | 11 |
2 files changed, 49 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index b6ea0d04a50..c6f8171bd18 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -11,6 +11,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Objects; +import java.util.function.Function; /** * The <i>cell_cast</i> tensor function creates a new tensor with the specified cell value type. @@ -62,20 +63,49 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM private Tensor cast(Tensor tensor, TensorType type) { Tensor.Builder builder = Tensor.Builder.of(type); TensorType.Value fromValueType = tensor.type().valueType(); + switch (fromValueType) { + case DOUBLE: + return castFromDouble(tensor, type); + case FLOAT: + case BFLOAT16: + case INT8: + return castFromSomeFloat(tensor, type); + default: + throw new IllegalStateException("Unexpected value type " + fromValueType); + } + } + + private Tensor castFromDouble(Tensor tensor, TensorType type) { + Tensor.Builder builder = Tensor.Builder.of(type); + var restrict = selectRestrict(type.valueType()); + for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + builder.cell(cell.getKey(), restrict.apply((float)cell.getDoubleValue())); + } + return builder.build(); + } + + private Tensor castFromSomeFloat(Tensor tensor, TensorType type) { + Tensor.Builder builder = Tensor.Builder.of(type); + var restrict = selectRestrict(type.valueType()); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Tensor.Cell cell = i.next(); - switch (fromValueType) { - case DOUBLE: builder.cell(cell.getKey(), cell.getDoubleValue()); break; - case FLOAT: builder.cell(cell.getKey(), cell.getFloatValue()); break; - case BFLOAT16: builder.cell(cell.getKey(), cell.getFloatValue()); break; - case INT8: builder.cell(cell.getKey(), cell.getFloatValue()); break; - default: - builder.cell(cell.getKey(), cell.getValue()); - } + builder.cell(cell.getKey(), restrict.apply(cell.getFloatValue())); } return builder.build(); } + static private Function<Float,Float> selectRestrict(TensorType.Value toValueType) { + switch (toValueType) { + case BFLOAT16: + return val -> Float.intBitsToFloat(Float.floatToRawIntBits(val) & ~0xffff); + case INT8: + return val -> (float)val.byteValue(); + default: + return val -> val; + } + } + @Override public String toString(ToStringContext context) { return "cell_cast(" + argument.toString(context) + ", " + valueType + ")"; diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java index bc10ecc3abd..c2957df4ac1 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java @@ -33,6 +33,17 @@ public class CellCastTestCase { assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType()); assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType()); assertEquals(tensor, tensor.cellCast(TensorType.Value.DOUBLE)); + + tensor = Tensor.from("tensor<double>(x{}):{{x:0}:2.25,{x:1}:1.00000000001,{x:2}:256.0,{x:3}:1.00390625}"); + var asFloat = Tensor.from("tensor<float>(x{}):{{x:0}:2.25,{x:1}:1.0,{x:2}:256.0,{x:3}:1.00390625}"); + var asBFloat16 = Tensor.from("tensor<bfloat16>(x{}):{{x:0}:2.25,{x:1}:1.0,{x:2}:256.0,{x:3}:1.0}"); + var asInt8 = Tensor.from("tensor<int8>(x{}):{{x:0}:2,{x:1}:1,{x:2}:0,{x:3}:1}"); + assertEquals(asFloat, tensor.cellCast(TensorType.Value.FLOAT)); + assertEquals(asBFloat16, tensor.cellCast(TensorType.Value.BFLOAT16)); + assertEquals(asInt8, tensor.cellCast(TensorType.Value.INT8)); + assertEquals(asBFloat16, asFloat.cellCast(TensorType.Value.BFLOAT16)); + assertEquals(asInt8, asFloat.cellCast(TensorType.Value.INT8)); + assertEquals(asInt8, asBFloat16.cellCast(TensorType.Value.INT8)); } } |