summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-09 06:45:16 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-09 10:31:36 +0000
commit527eada361b1000fc28cca04a7234845c2df839c (patch)
treec125891e20484af3e6ed4a4585760b9b481f104a /vespajlib
parentcc6d1c271efffd4a3321478884fb82c7e6141091 (diff)
restrict values to fit into target cell type
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java46
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java11
2 files changed, 49 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index b6ea0d04a50..c6f8171bd18 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -11,6 +11,7 @@ import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
+import java.util.function.Function;
/**
* The <i>cell_cast</i> tensor function creates a new tensor with the specified cell value type.
@@ -62,20 +63,49 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
private Tensor cast(Tensor tensor, TensorType type) {
Tensor.Builder builder = Tensor.Builder.of(type);
TensorType.Value fromValueType = tensor.type().valueType();
+ switch (fromValueType) {
+ case DOUBLE:
+ return castFromDouble(tensor, type);
+ case FLOAT:
+ case BFLOAT16:
+ case INT8:
+ return castFromSomeFloat(tensor, type);
+ default:
+ throw new IllegalStateException("Unexpected value type " + fromValueType);
+ }
+ }
+
+ private Tensor castFromDouble(Tensor tensor, TensorType type) {
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ var restrict = selectRestrict(type.valueType());
+ for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ builder.cell(cell.getKey(), restrict.apply((float)cell.getDoubleValue()));
+ }
+ return builder.build();
+ }
+
+ private Tensor castFromSomeFloat(Tensor tensor, TensorType type) {
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ var restrict = selectRestrict(type.valueType());
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Tensor.Cell cell = i.next();
- switch (fromValueType) {
- case DOUBLE: builder.cell(cell.getKey(), cell.getDoubleValue()); break;
- case FLOAT: builder.cell(cell.getKey(), cell.getFloatValue()); break;
- case BFLOAT16: builder.cell(cell.getKey(), cell.getFloatValue()); break;
- case INT8: builder.cell(cell.getKey(), cell.getFloatValue()); break;
- default:
- builder.cell(cell.getKey(), cell.getValue());
- }
+ builder.cell(cell.getKey(), restrict.apply(cell.getFloatValue()));
}
return builder.build();
}
+ static private Function<Float,Float> selectRestrict(TensorType.Value toValueType) {
+ switch (toValueType) {
+ case BFLOAT16:
+ return val -> Float.intBitsToFloat(Float.floatToRawIntBits(val) & ~0xffff);
+ case INT8:
+ return val -> (float)val.byteValue();
+ default:
+ return val -> val;
+ }
+ }
+
@Override
public String toString(ToStringContext context) {
return "cell_cast(" + argument.toString(context) + ", " + valueType + ")";
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java
index bc10ecc3abd..c2957df4ac1 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java
@@ -33,6 +33,17 @@ public class CellCastTestCase {
assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
assertEquals(tensor, tensor.cellCast(TensorType.Value.DOUBLE));
+
+ tensor = Tensor.from("tensor<double>(x{}):{{x:0}:2.25,{x:1}:1.00000000001,{x:2}:256.0,{x:3}:1.00390625}");
+ var asFloat = Tensor.from("tensor<float>(x{}):{{x:0}:2.25,{x:1}:1.0,{x:2}:256.0,{x:3}:1.00390625}");
+ var asBFloat16 = Tensor.from("tensor<bfloat16>(x{}):{{x:0}:2.25,{x:1}:1.0,{x:2}:256.0,{x:3}:1.0}");
+ var asInt8 = Tensor.from("tensor<int8>(x{}):{{x:0}:2,{x:1}:1,{x:2}:0,{x:3}:1}");
+ assertEquals(asFloat, tensor.cellCast(TensorType.Value.FLOAT));
+ assertEquals(asBFloat16, tensor.cellCast(TensorType.Value.BFLOAT16));
+ assertEquals(asInt8, tensor.cellCast(TensorType.Value.INT8));
+ assertEquals(asBFloat16, asFloat.cellCast(TensorType.Value.BFLOAT16));
+ assertEquals(asInt8, asFloat.cellCast(TensorType.Value.INT8));
+ assertEquals(asInt8, asBFloat16.cellCast(TensorType.Value.INT8));
}
}