summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-09 06:45:16 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-09 10:31:36 +0000
commit527eada361b1000fc28cca04a7234845c2df839c (patch)
treec125891e20484af3e6ed4a4585760b9b481f104a /vespajlib/src/main/java/com/yahoo/tensor
parentcc6d1c271efffd4a3321478884fb82c7e6141091 (diff)
restrict values to fit into target cell type
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java46
1 files changed, 38 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index b6ea0d04a50..c6f8171bd18 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -11,6 +11,7 @@ import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
+import java.util.function.Function;
/**
* The <i>cell_cast</i> tensor function creates a new tensor with the specified cell value type.
@@ -62,20 +63,49 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
private Tensor cast(Tensor tensor, TensorType type) {
Tensor.Builder builder = Tensor.Builder.of(type);
TensorType.Value fromValueType = tensor.type().valueType();
+ switch (fromValueType) {
+ case DOUBLE:
+ return castFromDouble(tensor, type);
+ case FLOAT:
+ case BFLOAT16:
+ case INT8:
+ return castFromSomeFloat(tensor, type);
+ default:
+ throw new IllegalStateException("Unexpected value type " + fromValueType);
+ }
+ }
+
+ private Tensor castFromDouble(Tensor tensor, TensorType type) {
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ var restrict = selectRestrict(type.valueType());
+ for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ builder.cell(cell.getKey(), restrict.apply((float)cell.getDoubleValue()));
+ }
+ return builder.build();
+ }
+
+ private Tensor castFromSomeFloat(Tensor tensor, TensorType type) {
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ var restrict = selectRestrict(type.valueType());
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Tensor.Cell cell = i.next();
- switch (fromValueType) {
- case DOUBLE: builder.cell(cell.getKey(), cell.getDoubleValue()); break;
- case FLOAT: builder.cell(cell.getKey(), cell.getFloatValue()); break;
- case BFLOAT16: builder.cell(cell.getKey(), cell.getFloatValue()); break;
- case INT8: builder.cell(cell.getKey(), cell.getFloatValue()); break;
- default:
- builder.cell(cell.getKey(), cell.getValue());
- }
+ builder.cell(cell.getKey(), restrict.apply(cell.getFloatValue()));
}
return builder.build();
}
+ static private Function<Float,Float> selectRestrict(TensorType.Value toValueType) {
+ switch (toValueType) {
+ case BFLOAT16:
+ return val -> Float.intBitsToFloat(Float.floatToRawIntBits(val) & ~0xffff);
+ case INT8:
+ return val -> (float)val.byteValue();
+ default:
+ return val -> val;
+ }
+ }
+
@Override
public String toString(ToStringContext context) {
return "cell_cast(" + argument.toString(context) + ", " + valueType + ")";