From 8ae580ad9d7a2b7bd2420c5f543a21ac0d7a7c9a Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 8 Apr 2021 11:32:30 +0200 Subject: Add bfloat16 and int8 to Java cell_cast function --- .../src/main/java/com/yahoo/tensor/functions/CellCast.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'vespajlib') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index d052e383c85..d853c4a9069 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -64,10 +64,14 @@ public class CellCast extends PrimitiveTensorFunction i = tensor.cellIterator(); i.hasNext(); ) { Tensor.Cell cell = i.next(); - if (fromValueType == TensorType.Value.FLOAT) { - builder.cell(cell.getKey(), cell.getFloatValue()); - } else if (fromValueType == TensorType.Value.DOUBLE) { + if (fromValueType == TensorType.Value.DOUBLE) { builder.cell(cell.getKey(), cell.getDoubleValue()); + } else if (fromValueType == TensorType.Value.FLOAT) { + builder.cell(cell.getKey(), cell.getFloatValue()); + } else if (fromValueType == TensorType.Value.BFLOAT16) { + builder.cell(cell.getKey(), cell.getFloatValue()); + } else if (fromValueType == TensorType.Value.INT8) { + builder.cell(cell.getKey(), cell.getFloatValue()); } else { builder.cell(cell.getKey(), cell.getValue()); } -- cgit v1.2.3