diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-10-26 22:15:42 +0200 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-10-26 22:15:42 +0200 |
commit | f8da735cc1116d5036d20a3f767f4e3b69d2f72b (patch) | |
tree | 3e3caf09898ec2688b0eeacf40b09819068f1c69 /model-integration/src/main/java/ai | |
parent | 14b16ae550784fed98910e04188429579e289f2d (diff) |
Add support and upgrade opset
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 952adc36621..2612702e99b 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import ai.onnxruntime.TensorInfo; import ai.onnxruntime.ValueInfo; +import ai.onnxruntime.platform.Fp16Conversions; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; @@ -56,7 +57,6 @@ class TensorConverter { throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions"); } IndexedTensor tensor = (IndexedTensor) vespaTensor; - ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder()); if (onnxTensorInfo.type == OnnxJavaType.FLOAT) { for (int i = 0; i < tensor.size(); i++) @@ -88,6 +88,17 @@ class TensorConverter { buffer.putLong((long) tensor.get(i)); return OnnxTensor.createTensor(environment, buffer.rewind().asLongBuffer(), tensor.shape()); } + if (onnxTensorInfo.type == OnnxJavaType.FLOAT16) { + for (int i = 0; i < tensor.size(); i++) { + buffer.putShort(Fp16Conversions.floatToFp16((float)tensor.get(i))); + } + return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape(), OnnxJavaType.FLOAT16); + } + if (onnxTensorInfo.type == OnnxJavaType.BFLOAT16) { + for (int i = 0; i < tensor.size(); i++) + buffer.putShort(Fp16Conversions.floatToBf16((float)tensor.get(i))); + return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape(), OnnxJavaType.BFLOAT16); + } throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type); } @@ -132,6 +143,16 @@ class TensorConverter { for (long i = 0; i < sizes.totalSize(); i++) builder.cellByDirectIndex(i, buffer.get()); } + else if (tensorInfo.type == OnnxJavaType.FLOAT16) { + ShortBuffer buffer = onnxTensor.getShortBuffer(); + for (long i = 0; i < sizes.totalSize(); i++) + builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get())); + } + else if (tensorInfo.type == OnnxJavaType.BFLOAT16) { + ShortBuffer buffer = onnxTensor.getShortBuffer(); + for (long i = 0; i < sizes.totalSize(); i++) + builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get()))); + } else { throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); } @@ -183,6 +204,7 @@ class TensorConverter { switch (onnxType) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8; case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return TensorType.Value.FLOAT; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT; case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE; } |