author    Jo Kristian Bergum <bergum@yahooinc.com>  2023-10-26 22:15:42 +0200
committer Jo Kristian Bergum <bergum@yahooinc.com>  2023-10-26 22:15:42 +0200
commit    f8da735cc1116d5036d20a3f767f4e3b69d2f72b (patch)
tree      3e3caf09898ec2688b0eeacf40b09819068f1c69 /model-integration
parent    14b16ae550784fed98910e04188429579e289f2d (diff)
Add float16/bfloat16 support and upgrade opset
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java   | 24
-rw-r--r--  model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java | 8
-rw-r--r--  model-integration/src/test/models/onnx/cast_bfloat16_float.onnx                            | 4
-rwxr-xr-x  model-integration/src/test/models/onnx/cast_bfloat16_float.py                             | 2
4 files changed, 31 insertions, 7 deletions
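
For context: the new conversion paths below lean on onnxruntime's Fp16Conversions helpers (ai.onnxruntime.platform), the same calls the diff introduces. A minimal, standalone sketch of the round-trip they provide; the sample value is illustrative:

    import ai.onnxruntime.platform.Fp16Conversions;

    public class Fp16RoundTrip {
        public static void main(String[] args) {
            float value = 1.5f;  // exactly representable in both fp16 and bf16
            short fp16 = Fp16Conversions.floatToFp16(value);  // IEEE 754 half precision
            short bf16 = Fp16Conversions.floatToBf16(value);  // high half of a float32
            System.out.println(Fp16Conversions.fp16ToFloat(fp16));  // prints 1.5
            System.out.println(Fp16Conversions.bf16ToFloat(bf16));  // prints 1.5
        }
    }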
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index 952adc36621..2612702e99b 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
+import ai.onnxruntime.platform.Fp16Conversions;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
@@ -56,7 +57,6 @@ class TensorConverter {
throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions");
}
IndexedTensor tensor = (IndexedTensor) vespaTensor;
-
ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder());
if (onnxTensorInfo.type == OnnxJavaType.FLOAT) {
for (int i = 0; i < tensor.size(); i++)
@@ -88,6 +88,17 @@ class TensorConverter {
buffer.putLong((long) tensor.get(i));
return OnnxTensor.createTensor(environment, buffer.rewind().asLongBuffer(), tensor.shape());
}
+ if (onnxTensorInfo.type == OnnxJavaType.FLOAT16) {
+ for (int i = 0; i < tensor.size(); i++) {
+ buffer.putShort(Fp16Conversions.floatToFp16((float)tensor.get(i)));
+ }
+ return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape(), OnnxJavaType.FLOAT16);
+ }
+ if (onnxTensorInfo.type == OnnxJavaType.BFLOAT16) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putShort(Fp16Conversions.floatToBf16((float)tensor.get(i)));
+ return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape(), OnnxJavaType.BFLOAT16);
+ }
throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type);
}
@@ -132,6 +143,16 @@ class TensorConverter {
for (long i = 0; i < sizes.totalSize(); i++)
builder.cellByDirectIndex(i, buffer.get());
}
+ else if (tensorInfo.type == OnnxJavaType.FLOAT16) {
+ ShortBuffer buffer = onnxTensor.getShortBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get()));
+ }
+ else if (tensorInfo.type == OnnxJavaType.BFLOAT16) {
+ ShortBuffer buffer = onnxTensor.getShortBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat(buffer.get()));
+ }
else {
throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type);
}
@@ -183,6 +204,7 @@ class TensorConverter {
switch (onnxType) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16;
+case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return TensorType.Value.FLOAT;  // Vespa tensors have no float16 cell type, so widen to float
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE;
}
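
Why the bfloat16 decode above can widen losslessly: a bfloat16 value is just the top 16 bits of a float32 (same sign and exponent, truncated mantissa). A bit-level sketch under that assumption; note the narrowing side here truncates, whereas the Fp16Conversions helpers may round to nearest even:

    public class Bf16Bits {

        // Widening is exact: place the 16 bits in the high half of a float32.
        static float bf16ToFloat(short bits) {
            return Float.intBitsToFloat((bits & 0xFFFF) << 16);
        }

        // Narrowing drops the low 16 mantissa bits (truncation variant).
        static short floatToBf16(float f) {
            return (short) (Float.floatToIntBits(f) >>> 16);
        }

        public static void main(String[] args) {
            System.out.println(bf16ToFloat(floatToBf16(1.0f)));  // prints 1.0
        }
    }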
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 9bb01fc8073..75da4d163bb 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -90,12 +90,14 @@ public class OnnxEvaluatorTest {
var runtime = new OnnxRuntime();
assertEvaluate(runtime, "add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
assertEvaluate(runtime, "add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
+ assertEvaluate(runtime, "add_float16.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
+// Add is not supported for bfloat16 in the ONNX operators, so Sign is used instead.
+ assertEvaluate(runtime, "sign_bfloat16.onnx", "tensor<bfloat16>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]");
+
assertEvaluate(runtime, "add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
assertEvaluate(runtime, "cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]");
assertEvaluate(runtime, "cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]");
-
-// ONNX Runtime 1.8.0 does not support much of bfloat16 yet
-// assertEvaluate("cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]");
+assertEvaluate(runtime, "cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]");
}
@Test
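
For readers unfamiliar with the test helper, a hedged sketch of what one of the new assertions exercises end to end. The evaluatorOf/evaluate calls are assumptions about the evaluator API, and the input/output names follow the test models in this directory:

    import com.yahoo.tensor.Tensor;
    import java.util.Map;

    static void evaluateAddFloat16(OnnxRuntime runtime) {
        var evaluator = runtime.evaluatorOf("src/test/models/onnx/add_float16.onnx");  // assumed loader
        Map<String, Tensor> inputs = Map.of(
                "input1", Tensor.from("tensor<float>(d0[1]):[1]"),
                "input2", Tensor.from("tensor<float>(d0[1]):[2]"));
        Tensor output = evaluator.evaluate(inputs, "output");  // expect tensor<float>(d0[1]):[3]
    }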
diff --git a/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx b/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx
index cb19592abf4..9fcbd7f1b3c 100644
--- a/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx
+++ b/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx
(binary content not shown: the serialized ONNX model was regenerated; per the script change below, its opset import moves from 12 to 19)
diff --git a/model-integration/src/test/models/onnx/cast_bfloat16_float.py b/model-integration/src/test/models/onnx/cast_bfloat16_float.py
index 51d04747958..952e4c469c1 100755
--- a/model-integration/src/test/models/onnx/cast_bfloat16_float.py
+++ b/model-integration/src/test/models/onnx/cast_bfloat16_float.py
@@ -20,5 +20,5 @@ graph_def = helper.make_graph(
[INPUT_1],
[OUTPUT],
)
-model_def = helper.make_model(graph_def, producer_name='cast_bfloat16_float.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+model_def = helper.make_model(graph_def, producer_name='cast_bfloat16_float.py', opset_imports=[onnx.OperatorSetIdProto(version=19)])
onnx.save(model_def, 'cast_bfloat16_float.onnx')
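
Context for the version bump: bfloat16 entered the ONNX type system at opset 13, so the old opset-12 declaration predated official Cast support for the type; regenerating the model at opset 19 brings it in line with the operator set the upgraded runtime targets.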