diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-10-26 22:15:42 +0200 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-10-26 22:15:42 +0200 |
commit | f8da735cc1116d5036d20a3f767f4e3b69d2f72b (patch) | |
tree | 3e3caf09898ec2688b0eeacf40b09819068f1c69 /model-integration | |
parent | 14b16ae550784fed98910e04188429579e289f2d (diff) |
Add support and upgrade opset
Diffstat (limited to 'model-integration')
4 files changed, 31 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 952adc36621..2612702e99b 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import ai.onnxruntime.TensorInfo; import ai.onnxruntime.ValueInfo; +import ai.onnxruntime.platform.Fp16Conversions; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; @@ -56,7 +57,6 @@ class TensorConverter { throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions"); } IndexedTensor tensor = (IndexedTensor) vespaTensor; - ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder()); if (onnxTensorInfo.type == OnnxJavaType.FLOAT) { for (int i = 0; i < tensor.size(); i++) @@ -88,6 +88,17 @@ class TensorConverter { buffer.putLong((long) tensor.get(i)); return OnnxTensor.createTensor(environment, buffer.rewind().asLongBuffer(), tensor.shape()); } + if (onnxTensorInfo.type == OnnxJavaType.FLOAT16) { + for (int i = 0; i < tensor.size(); i++) { + buffer.putShort(Fp16Conversions.floatToFp16((float)tensor.get(i))); + } + return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape(), OnnxJavaType.FLOAT16); + } + if (onnxTensorInfo.type == OnnxJavaType.BFLOAT16) { + for (int i = 0; i < tensor.size(); i++) + buffer.putShort(Fp16Conversions.floatToBf16((float)tensor.get(i))); + return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape(), OnnxJavaType.BFLOAT16); + } throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type); } @@ -132,6 +143,16 @@ class TensorConverter { for (long i = 0; i < sizes.totalSize(); i++) builder.cellByDirectIndex(i, buffer.get()); } + else if (tensorInfo.type == OnnxJavaType.FLOAT16) { + ShortBuffer buffer = onnxTensor.getShortBuffer(); + for (long i = 0; i < sizes.totalSize(); i++) + builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get())); + } + else if (tensorInfo.type == OnnxJavaType.BFLOAT16) { + ShortBuffer buffer = onnxTensor.getShortBuffer(); + for (long i = 0; i < sizes.totalSize(); i++) + builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get()))); + } else { throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); } @@ -183,6 +204,7 @@ class TensorConverter { switch (onnxType) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8; case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return TensorType.Value.FLOAT; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT; case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE; } diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index 9bb01fc8073..75da4d163bb 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -90,12 +90,14 @@ public class OnnxEvaluatorTest { var runtime = new OnnxRuntime(); assertEvaluate(runtime, "add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); assertEvaluate(runtime, "add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); + assertEvaluate(runtime, "add_float16.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); + //Add is not a supported operation for bfloat16 types in onnx operators. + assertEvaluate(runtime, "sign_bfloat16.onnx", "tensor<bfloat16>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]"); + assertEvaluate(runtime, "add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]"); assertEvaluate(runtime, "cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]"); assertEvaluate(runtime, "cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]"); - - // ONNX Runtime 1.8.0 does not support much of bfloat16 yet - // assertEvaluate("cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]"); + assertEvaluate(runtime,"cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]"); } @Test diff --git a/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx b/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx index cb19592abf4..9fcbd7f1b3c 100644 --- a/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx +++ b/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx @@ -1,4 +1,4 @@ -cast_bfloat16_float.py:U +cast_bfloat16_float.py:U ! input1output"Cast* to castZ @@ -9,4 +9,4 @@ output -B
\ No newline at end of file +B
\ No newline at end of file diff --git a/model-integration/src/test/models/onnx/cast_bfloat16_float.py b/model-integration/src/test/models/onnx/cast_bfloat16_float.py index 51d04747958..952e4c469c1 100755 --- a/model-integration/src/test/models/onnx/cast_bfloat16_float.py +++ b/model-integration/src/test/models/onnx/cast_bfloat16_float.py @@ -20,5 +20,5 @@ graph_def = helper.make_graph( [INPUT_1], [OUTPUT], ) -model_def = helper.make_model(graph_def, producer_name='cast_bfloat16_float.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) +model_def = helper.make_model(graph_def, producer_name='cast_bfloat16_float.py', opset_imports=[onnx.OperatorSetIdProto(version=19)]) onnx.save(model_def, 'cast_bfloat16_float.onnx') |