summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-08 11:32:30 +0200
committerLester Solbakken <lesters@oath.com>2021-04-08 11:32:30 +0200
commit8ae580ad9d7a2b7bd2420c5f543a21ac0d7a7c9a (patch)
treea334de5eaf933998ccdf18c45aeedf73180ff377 /vespajlib
parent049e9a325c8142958909d0464da12a56e5a8f638 (diff)
Add bfloat16 and int8 to Java cell_cast function
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java10
1 files changed, 7 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index d052e383c85..d853c4a9069 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -64,10 +64,14 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
TensorType.Value fromValueType = tensor.type().valueType();
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Tensor.Cell cell = i.next();
- if (fromValueType == TensorType.Value.FLOAT) {
- builder.cell(cell.getKey(), cell.getFloatValue());
- } else if (fromValueType == TensorType.Value.DOUBLE) {
+ if (fromValueType == TensorType.Value.DOUBLE) {
builder.cell(cell.getKey(), cell.getDoubleValue());
+ } else if (fromValueType == TensorType.Value.FLOAT) {
+ builder.cell(cell.getKey(), cell.getFloatValue());
+ } else if (fromValueType == TensorType.Value.BFLOAT16) {
+ builder.cell(cell.getKey(), cell.getFloatValue());
+ } else if (fromValueType == TensorType.Value.INT8) {
+ builder.cell(cell.getKey(), cell.getFloatValue());
} else {
builder.cell(cell.getKey(), cell.getValue());
}