author     Henning Baldersheim <balder@yahoo-inc.com>  2024-01-18 23:46:16 +0100
committer  Henning Baldersheim <balder@yahoo-inc.com>  2024-01-18 23:46:16 +0100
commit     914cad21b94a09f2ec340572491681eba8108834 (patch)
tree       128e2c509e250ec483e0f46c73a11e14bc939592 /model-integration
parent     ac42a2d1ab601e0b494fdbee46e96fc92c00fe13 (diff)
Cache sizes.totalSize() in a variable to prevent recomputation.
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java  39
1 file changed, 19 insertions(+), 20 deletions(-)
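
The change itself is small: sizes.totalSize() is computed once before the per-element copy loops instead of being re-evaluated in every loop condition. A minimal standalone sketch of that hoisting idiom, using placeholder types rather than the actual Vespa DimensionSizes/IndexedTensor classes:

// Illustrative sketch only; placeholder names, not the actual Vespa classes.
import java.util.function.LongConsumer;

final class HoistTotalSizeSketch {

    // Stand-in for DimensionSizes.totalSize(): product of all dimension sizes.
    static long totalSize(long[] dimensionSizes) {
        long product = 1;
        for (long size : dimensionSizes) product *= size;
        return product;
    }

    // Before: the loop condition re-evaluates totalSize(dims) on every iteration.
    static void fillBefore(long[] dims, LongConsumer writeCell) {
        for (long i = 0; i < totalSize(dims); i++)
            writeCell.accept(i);
    }

    // After: the size is computed once and reused, mirroring the commit.
    static void fillAfter(long[] dims, LongConsumer writeCell) {
        long total = totalSize(dims);
        for (long i = 0; i < total; i++)
            writeCell.accept(i);
    }
}

Whether the JIT would eliminate the repeated call on its own depends on inlining and on totalSize() being visibly side-effect free; hoisting it manually makes the intent explicit either way.
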
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index 2612702e99b..07f2aea4ab6 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -53,10 +53,9 @@ class TensorConverter {
static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment)
throws OrtException
{
- if ( ! (vespaTensor instanceof IndexedTensor)) {
+ if ( ! (vespaTensor instanceof IndexedTensor tensor)) {
throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions");
}
- IndexedTensor tensor = (IndexedTensor) vespaTensor;
ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder());
if (onnxTensorInfo.type == OnnxJavaType.FLOAT) {
for (int i = 0; i < tensor.size(); i++)
@@ -103,54 +102,54 @@ class TensorConverter {
}
static Tensor toVespaTensor(OnnxValue onnxValue) {
- if ( ! (onnxValue instanceof OnnxTensor)) {
+ if ( ! (onnxValue instanceof OnnxTensor onnxTensor)) {
throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
}
- OnnxTensor onnxTensor = (OnnxTensor) onnxValue;
TensorInfo tensorInfo = onnxTensor.getInfo();
TensorType type = toVespaType(onnxTensor.getInfo());
DimensionSizes sizes = sizesFromType(type);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes);
+ long totalSize = sizes.totalSize();
if (tensorInfo.type == OnnxJavaType.FLOAT) {
FloatBuffer buffer = onnxTensor.getFloatBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.DOUBLE) {
DoubleBuffer buffer = onnxTensor.getDoubleBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.INT8) {
ByteBuffer buffer = onnxTensor.getByteBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.INT16) {
ShortBuffer buffer = onnxTensor.getShortBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.INT32) {
IntBuffer buffer = onnxTensor.getIntBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.INT64) {
LongBuffer buffer = onnxTensor.getLongBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.FLOAT16) {
ShortBuffer buffer = onnxTensor.getShortBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get()));
}
else if (tensorInfo.type == OnnxJavaType.BFLOAT16) {
ShortBuffer buffer = onnxTensor.getShortBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get())));
}
else {
@@ -201,14 +200,14 @@ class TensorConverter {
}
static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) {
- switch (onnxType) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return TensorType.Value.FLOAT;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE;
- }
- return TensorType.Value.DOUBLE;
+ return switch (onnxType) {
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 -> TensorType.Value.INT8;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 -> TensorType.Value.BFLOAT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 -> TensorType.Value.FLOAT;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT -> TensorType.Value.FLOAT;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE -> TensorType.Value.DOUBLE;
+ default -> TensorType.Value.DOUBLE;
+ };
}
static private TensorInfo toTensorInfo(ValueInfo valueInfo) {
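
Two of the hunks above are language-level cleanups rather than performance changes: the cast-after-instanceof is replaced by pattern matching for instanceof (Java 16+), and the value-type switch statement becomes a switch expression (Java 14+). A standalone illustration of both idioms, again with placeholder types rather than the real Vespa/ONNX classes:

// Illustrative sketch only; placeholder types, not the actual Vespa/ONNX classes.
final class RefactorIdiomsSketch {

    interface Value {}

    static final class TensorValue implements Value {
        final double[] cells;
        TensorValue(double[] cells) { this.cells = cells; }
    }

    enum ElementType { INT8, BFLOAT16, FLOAT16, FLOAT, DOUBLE, INT32, INT64 }

    // Pattern matching for instanceof (Java 16+): the type test and the binding
    // happen in one expression, so the separate cast statement can be dropped.
    static double[] cellsOf(Value value) {
        if ( ! (value instanceof TensorValue tensor)) {
            throw new IllegalArgumentException("Only tensor values are supported");
        }
        return tensor.cells;
    }

    // Switch expression (Java 14+): each arm yields a value, and the default arm
    // takes over the role of the old trailing fallback return.
    static String vespaValueType(ElementType type) {
        return switch (type) {
            case INT8           -> "int8";
            case BFLOAT16       -> "bfloat16";
            case FLOAT16, FLOAT -> "float";
            case DOUBLE         -> "double";
            default             -> "double";
        };
    }
}
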