From 914cad21b94a09f2ec340572491681eba8108834 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Thu, 18 Jan 2024 23:46:16 +0100 Subject: Cache sizes.totalSize() in variable to prevent recomputation. --- .../evaluator/TensorConverter.java | 39 +++++++++++----------- 1 file changed, 19 insertions(+), 20 deletions(-) (limited to 'model-integration') diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 2612702e99b..07f2aea4ab6 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -53,10 +53,9 @@ class TensorConverter { static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment) throws OrtException { - if ( ! (vespaTensor instanceof IndexedTensor)) { + if ( ! (vespaTensor instanceof IndexedTensor tensor)) { throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions"); } - IndexedTensor tensor = (IndexedTensor) vespaTensor; ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder()); if (onnxTensorInfo.type == OnnxJavaType.FLOAT) { for (int i = 0; i < tensor.size(); i++) @@ -103,54 +102,54 @@ class TensorConverter { } static Tensor toVespaTensor(OnnxValue onnxValue) { - if ( ! (onnxValue instanceof OnnxTensor)) { + if ( ! (onnxValue instanceof OnnxTensor onnxTensor)) { throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported"); } - OnnxTensor onnxTensor = (OnnxTensor) onnxValue; TensorInfo tensorInfo = onnxTensor.getInfo(); TensorType type = toVespaType(onnxTensor.getInfo()); DimensionSizes sizes = sizesFromType(type); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes); + long totalSize = sizes.totalSize(); if (tensorInfo.type == OnnxJavaType.FLOAT) { FloatBuffer buffer = onnxTensor.getFloatBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.DOUBLE) { DoubleBuffer buffer = onnxTensor.getDoubleBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.INT8) { ByteBuffer buffer = onnxTensor.getByteBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.INT16) { ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.INT32) { IntBuffer buffer = onnxTensor.getIntBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.INT64) { LongBuffer buffer = onnxTensor.getLongBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.FLOAT16) { ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get())); } else if (tensorInfo.type == OnnxJavaType.BFLOAT16) { ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get()))); } else { @@ -201,14 +200,14 @@ class TensorConverter { } static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) { - switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return TensorType.Value.FLOAT; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE; - } - return TensorType.Value.DOUBLE; + return switch (onnxType) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 -> TensorType.Value.INT8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 -> TensorType.Value.BFLOAT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 -> TensorType.Value.FLOAT; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT -> TensorType.Value.FLOAT; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE -> TensorType.Value.DOUBLE; + default -> TensorType.Value.DOUBLE; + }; } static private TensorInfo toTensorInfo(ValueInfo valueInfo) { -- cgit v1.2.3