author     Jo Kristian Bergum <bergum@yahooinc.com>  2023-10-26 22:15:42 +0200
committer  Jo Kristian Bergum <bergum@yahooinc.com>  2023-10-26 22:15:42 +0200
commit     f8da735cc1116d5036d20a3f767f4e3b69d2f72b (patch)
tree       3e3caf09898ec2688b0eeacf40b09819068f1c69 /model-integration/src/main/java/ai
parent     14b16ae550784fed98910e04188429579e289f2d (diff)
Add float16 and bfloat16 tensor support and upgrade opset
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java | 24
1 file changed, 23 insertions(+), 1 deletion(-)
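
The new conversion paths go through onnxruntime's ai.onnxruntime.platform.Fp16Conversions helpers, imported at the top of the patch. A minimal standalone sketch of the round trip those helpers perform (not part of the patch; the class name Fp16RoundTrip is made up for illustration):

import ai.onnxruntime.platform.Fp16Conversions;

// Minimal sketch (not part of the patch): round-tripping a float through the
// fp16 and bf16 helpers the new conversion paths rely on.
public class Fp16RoundTrip {

    public static void main(String[] args) {
        float value = 0.1f;

        // fp16 (IEEE 754 half): 1 sign, 5 exponent, 10 mantissa bits
        short fp16Bits = Fp16Conversions.floatToFp16(value);
        float viaFp16 = Fp16Conversions.fp16ToFloat(fp16Bits);

        // bf16: 1 sign, 8 exponent, 7 mantissa bits (float32 range, less precision)
        short bf16Bits = Fp16Conversions.floatToBf16(value);
        float viaBf16 = Fp16Conversions.bf16ToFloat(bf16Bits);

        System.out.println("original: " + value);
        System.out.println("via fp16: " + viaFp16);  // close to 0.1, small rounding error
        System.out.println("via bf16: " + viaBf16);  // coarser: only ~7 mantissa bits
    }
}

Both half-precision formats are widened to float on the Vespa side, as the ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 -> TensorType.Value.FLOAT mapping at the end of the diff shows.
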
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index 952adc36621..2612702e99b 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
+import ai.onnxruntime.platform.Fp16Conversions;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
@@ -56,7 +57,6 @@ class TensorConverter {
throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions");
}
IndexedTensor tensor = (IndexedTensor) vespaTensor;
-
ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder());
if (onnxTensorInfo.type == OnnxJavaType.FLOAT) {
for (int i = 0; i < tensor.size(); i++)
@@ -88,6 +88,17 @@ class TensorConverter {
buffer.putLong((long) tensor.get(i));
return OnnxTensor.createTensor(environment, buffer.rewind().asLongBuffer(), tensor.shape());
}
+ if (onnxTensorInfo.type == OnnxJavaType.FLOAT16) {
+ for (int i = 0; i < tensor.size(); i++) {
+ buffer.putShort(Fp16Conversions.floatToFp16((float)tensor.get(i)));
+ }
+ return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape(), OnnxJavaType.FLOAT16);
+ }
+ if (onnxTensorInfo.type == OnnxJavaType.BFLOAT16) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putShort(Fp16Conversions.floatToBf16((float)tensor.get(i)));
+ return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape(), OnnxJavaType.BFLOAT16);
+ }
throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type);
}
@@ -132,6 +143,16 @@ class TensorConverter {
for (long i = 0; i < sizes.totalSize(); i++)
builder.cellByDirectIndex(i, buffer.get());
}
+ else if (tensorInfo.type == OnnxJavaType.FLOAT16) {
+ ShortBuffer buffer = onnxTensor.getShortBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get()));
+ }
+ else if (tensorInfo.type == OnnxJavaType.BFLOAT16) {
+ ShortBuffer buffer = onnxTensor.getShortBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get())));
+ }
else {
throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type);
}
@@ -183,6 +204,7 @@ class TensorConverter {
switch (onnxType) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return TensorType.Value.FLOAT;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE;
}
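
For context, a self-contained sketch of the two buffer paths added above, using only java.nio and Fp16Conversions; a plain float[] stands in for the Vespa IndexedTensor, and the class name Fp16BufferSketch is hypothetical:

import ai.onnxruntime.platform.Fp16Conversions;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;

// Sketch of the fp16 buffer handling added by the patch, with a plain float[]
// standing in for the Vespa IndexedTensor.
public class Fp16BufferSketch {

    public static void main(String[] args) {
        float[] values = { 0.5f, -1.25f, 3.14159f };

        // Vespa -> ONNX direction: two bytes per fp16 element in a direct buffer,
        // native byte order, mirroring the allocation in the patch's FLOAT16 branch.
        ByteBuffer buffer = ByteBuffer.allocateDirect(values.length * 2).order(ByteOrder.nativeOrder());
        for (float v : values)
            buffer.putShort(Fp16Conversions.floatToFp16(v));
        buffer.rewind();

        // ONNX -> Vespa direction: view the same bytes as shorts and widen back to
        // float, as the patch does with onnxTensor.getShortBuffer().
        ShortBuffer shorts = buffer.asShortBuffer();
        for (int i = 0; i < values.length; i++)
            System.out.println(values[i] + " -> " + Fp16Conversions.fp16ToFloat(shorts.get(i)));
    }
}

Native byte order matches what the existing conversion code already uses when filling the direct buffer for the other element types.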