diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java | 46 |
1 files changed, 38 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index b6ea0d04a50..c6f8171bd18 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -11,6 +11,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Objects; +import java.util.function.Function; /** * The <i>cell_cast</i> tensor function creates a new tensor with the specified cell value type. @@ -62,20 +63,49 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM private Tensor cast(Tensor tensor, TensorType type) { Tensor.Builder builder = Tensor.Builder.of(type); TensorType.Value fromValueType = tensor.type().valueType(); + switch (fromValueType) { + case DOUBLE: + return castFromDouble(tensor, type); + case FLOAT: + case BFLOAT16: + case INT8: + return castFromSomeFloat(tensor, type); + default: + throw new IllegalStateException("Unexpected value type " + fromValueType); + } + } + + private Tensor castFromDouble(Tensor tensor, TensorType type) { + Tensor.Builder builder = Tensor.Builder.of(type); + var restrict = selectRestrict(type.valueType()); + for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + builder.cell(cell.getKey(), restrict.apply((float)cell.getDoubleValue())); + } + return builder.build(); + } + + private Tensor castFromSomeFloat(Tensor tensor, TensorType type) { + Tensor.Builder builder = Tensor.Builder.of(type); + var restrict = selectRestrict(type.valueType()); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Tensor.Cell cell = i.next(); - switch (fromValueType) { - case DOUBLE: builder.cell(cell.getKey(), cell.getDoubleValue()); break; - case FLOAT: builder.cell(cell.getKey(), cell.getFloatValue()); break; - case BFLOAT16: builder.cell(cell.getKey(), cell.getFloatValue()); break; - case INT8: builder.cell(cell.getKey(), cell.getFloatValue()); break; - default: - builder.cell(cell.getKey(), cell.getValue()); - } + builder.cell(cell.getKey(), restrict.apply(cell.getFloatValue())); } return builder.build(); } + static private Function<Float,Float> selectRestrict(TensorType.Value toValueType) { + switch (toValueType) { + case BFLOAT16: + return val -> Float.intBitsToFloat(Float.floatToRawIntBits(val) & ~0xffff); + case INT8: + return val -> (float)val.byteValue(); + default: + return val -> val; + } + } + @Override public String toString(ToStringContext context) { return "cell_cast(" + argument.toString(context) + ", " + valueType + ")"; |