diff options
author | Lester Solbakken <lesters@oath.com> | 2021-05-19 14:41:57 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-05-19 14:41:57 +0200 |
commit | 77807b53619a6b1449610c598f391b024eb52930 (patch) | |
tree | 65117eab0d5e5cafee49a04b10d1c79f1e3431fa /model-integration/src/main/java/ai | |
parent | bd35a66573c5b6cbf05f2d875cef00817b7d23c1 (diff) |
Revert "Revert "Add ONNX-RT evaluator to model-integration module""
This reverts commit 97080252fac0ba45b58f9d0efb56603da518428f.
Diffstat (limited to 'model-integration/src/main/java/ai')
3 files changed, 265 insertions, 0 deletions
// diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
// new file mode 100644
// index 00000000000..59ad20b7714
// --- /dev/null
// +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
// @@ -0,0 +1,79 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;


/**
 * Evaluates an ONNX Model by deferring to ONNX Runtime.
 *
 * <p>Every {@link OrtException} raised by the runtime is rethrown as a
 * {@link RuntimeException} with the original exception as its cause.</p>
 *
 * @author lesters
 */
public class OnnxEvaluator {

    private final OrtEnvironment environment;
    private final OrtSession session;

    /**
     * Creates an evaluator by opening an ONNX Runtime session for the given model.
     *
     * @param modelPath path of the ONNX model, passed as-is to {@code OrtEnvironment.createSession}
     * @throws RuntimeException if ONNX Runtime fails to create the session
     */
    public OnnxEvaluator(String modelPath) {
        try {
            environment = OrtEnvironment.getEnvironment();
            session = environment.createSession(modelPath, new OrtSession.SessionOptions());
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Runs the model and returns a single named output.
     *
     * @param inputs the input tensors, keyed by model input name
     * @param output the name of the model output to compute
     * @return the requested output converted to a Vespa tensor
     * @throws RuntimeException if ONNX Runtime fails during conversion or evaluation
     */
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
            try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) {
                // Only one output was requested, so it is the first (and only) entry of the result.
                return TensorConverter.toVespaTensor(result.get(0));
            }
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        } finally {
            // The input tensors are closeable and are not owned by the Result, so release them here.
            closeAll(onnxInputs);
        }
    }

    /**
     * Runs the model and returns all of its outputs.
     *
     * @param inputs the input tensors, keyed by model input name
     * @return every output of the run converted to a Vespa tensor, keyed by output name
     * @throws RuntimeException if ONNX Runtime fails during conversion or evaluation
     */
    public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
        Map<String, OnnxTensor> onnxInputs = null;
        try {
            onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
            Map<String, Tensor> outputs = new HashMap<>();
            try (OrtSession.Result result = session.run(onnxInputs)) {
                // Convert inside the try-with-resources: the values are only read while the Result is open.
                for (Map.Entry<String, OnnxValue> output : result) {
                    outputs.put(output.getKey(), TensorConverter.toVespaTensor(output.getValue()));
                }
                return outputs;
            }
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        } finally {
            closeAll(onnxInputs);
        }
    }

    /**
     * Returns the Vespa tensor types of the model inputs, keyed by input name.
     *
     * @throws RuntimeException if ONNX Runtime fails to report the input info
     */
    public Map<String, TensorType> getInputInfo() {
        try {
            return TensorConverter.toVespaTypes(session.getInputInfo());
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /**
     * Returns the Vespa tensor types of the model outputs, keyed by output name.
     *
     * @throws RuntimeException if ONNX Runtime fails to report the output info
     */
    public Map<String, TensorType> getOutputInfo() {
        try {
            return TensorConverter.toVespaTypes(session.getOutputInfo());
        } catch (OrtException e) {
            throw new RuntimeException("ONNX Runtime exception", e);
        }
    }

    /** Closes every tensor in the given map; a null map (conversion never completed) is ignored. */
    private static void closeAll(Map<String, OnnxTensor> tensors) {
        if (tensors == null) {
            return;
        }
        for (OnnxTensor tensor : tensors.values()) {
            tensor.close();
        }
    }

}

// diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
// new file mode 100644
// index 00000000000..c1f973300d6
// --- /dev/null
// +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
// @@ -0,0 +1,181 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;


/**
 * Converts tensors and tensor types between the Vespa and ONNX Runtime representations.
 *
 * @author lesters
 */
class TensorConverter {

    /**
     * Converts each Vespa tensor to an ONNX tensor of the element type the model declares for that input.
     *
     * <p>The returned tensors are closeable and become the caller's responsibility. If any
     * conversion fails, the tensors created so far are closed before the exception propagates.</p>
     *
     * @param tensorMap the Vespa tensors, keyed by model input name
     * @param env the ONNX Runtime environment used to create the tensors
     * @param session the session whose input info decides each tensor's element type
     * @return the converted tensors, keyed by the same names
     * @throws IllegalArgumentException if a name is not an input of the model, or a tensor cannot be converted
     * @throws OrtException if ONNX Runtime fails to report input info or to create a tensor
     */
    static Map<String, OnnxTensor> toOnnxTensors(Map<String, Tensor> tensorMap, OrtEnvironment env, OrtSession session)
            throws OrtException
    {
        // Query the model inputs once instead of once per tensor.
        Map<String, NodeInfo> inputInfo = session.getInputInfo();
        Map<String, OnnxTensor> result = new HashMap<>();
        try {
            for (Map.Entry<String, Tensor> input : tensorMap.entrySet()) {
                String name = input.getKey();
                NodeInfo nodeInfo = inputInfo.get(name);
                if (nodeInfo == null) {
                    throw new IllegalArgumentException("ONNX model has no input named '" + name +
                                                       "'. Available inputs: " + inputInfo.keySet());
                }
                TensorInfo onnxTensorInfo = toTensorInfo(nodeInfo.getInfo());
                result.put(name, toOnnxTensor(input.getValue(), onnxTensorInfo, env));
            }
        } catch (OrtException | RuntimeException e) {
            // Do not leak the tensors that were already created when a later one fails.
            for (OnnxTensor created : result.values()) {
                created.close();
            }
            throw e;
        }
        return result;
    }

    /**
     * Converts a single Vespa tensor to an ONNX tensor of the given element type.
     *
     * <p>Cell values are written in direct index order into a native-order direct buffer;
     * for the integer types each value is narrowed with a plain Java cast.</p>
     *
     * @param vespaTensor the tensor to convert; must be an {@link IndexedTensor}
     * @param onnxTensorInfo supplies the ONNX element type to convert to
     * @param environment the ONNX Runtime environment used to create the tensor
     * @return a new ONNX tensor with the shape of {@code vespaTensor}
     * @throws IllegalArgumentException if the tensor is not indexed, is too large for a single
     *                                  buffer, or the element type is not supported
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment)
            throws OrtException
    {
        if ( ! (vespaTensor instanceof IndexedTensor)) {
            throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions");
        }
        IndexedTensor tensor = (IndexedTensor) vespaTensor;
        long cells = tensor.size();
        long[] shape = tensor.shape();

        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferBytes(cells, onnxTensorInfo.type.size))
                                      .order(ByteOrder.nativeOrder());
        switch (onnxTensorInfo.type) {
            case FLOAT:
                for (long i = 0; i < cells; i++) {
                    buffer.putFloat(tensor.getFloat(i));
                }
                buffer.rewind();
                return OnnxTensor.createTensor(environment, buffer.asFloatBuffer(), shape);
            case DOUBLE:
                for (long i = 0; i < cells; i++) {
                    buffer.putDouble(tensor.get(i));
                }
                buffer.rewind();
                return OnnxTensor.createTensor(environment, buffer.asDoubleBuffer(), shape);
            case INT8:
                for (long i = 0; i < cells; i++) {
                    buffer.put((byte) tensor.get(i));
                }
                buffer.rewind();
                return OnnxTensor.createTensor(environment, buffer, shape);
            case INT16:
                for (long i = 0; i < cells; i++) {
                    buffer.putShort((short) tensor.get(i));
                }
                buffer.rewind();
                return OnnxTensor.createTensor(environment, buffer.asShortBuffer(), shape);
            case INT32:
                for (long i = 0; i < cells; i++) {
                    buffer.putInt((int) tensor.get(i));
                }
                buffer.rewind();
                return OnnxTensor.createTensor(environment, buffer.asIntBuffer(), shape);
            case INT64:
                for (long i = 0; i < cells; i++) {
                    buffer.putLong((long) tensor.get(i));
                }
                buffer.rewind();
                return OnnxTensor.createTensor(environment, buffer.asLongBuffer(), shape);
            default:
                throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type);
        }
    }

    /** Returns cells * bytesPerCell as an int, failing loudly instead of silently overflowing. */
    private static int bufferBytes(long cells, int bytesPerCell) {
        try {
            return Math.toIntExact(Math.multiplyExact(cells, (long) bytesPerCell));
        } catch (ArithmeticException e) {
            throw new IllegalArgumentException("Tensor with " + cells + " cells of " + bytesPerCell +
                                               " bytes each is too large to convert to an ONNX tensor", e);
        }
    }

    /**
     * Converts an ONNX value to a Vespa indexed tensor by copying its cells in buffer order.
     *
     * @param onnxValue the value to convert; must be an {@link OnnxTensor}
     * @return a Vespa tensor whose type is derived from the ONNX tensor info
     * @throws IllegalArgumentException if the value is not a tensor, has a dimension without a
     *                                  positive size, or has an unsupported element type
     */
    static Tensor toVespaTensor(OnnxValue onnxValue) {
        if ( ! (onnxValue instanceof OnnxTensor)) {
            throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
        }
        OnnxTensor onnxTensor = (OnnxTensor) onnxValue;
        TensorInfo tensorInfo = onnxTensor.getInfo();

        TensorType type = toVespaType(tensorInfo);
        DimensionSizes sizes = sizesFromType(type);
        long cells = sizes.totalSize();

        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes);
        switch (tensorInfo.type) {
            case FLOAT: {
                FloatBuffer buffer = onnxTensor.getFloatBuffer();
                for (long i = 0; i < cells; i++) {
                    builder.cellByDirectIndex(i, buffer.get());
                }
                break;
            }
            case DOUBLE: {
                DoubleBuffer buffer = onnxTensor.getDoubleBuffer();
                for (long i = 0; i < cells; i++) {
                    builder.cellByDirectIndex(i, buffer.get());
                }
                break;
            }
            case INT8: {
                ByteBuffer buffer = onnxTensor.getByteBuffer();
                for (long i = 0; i < cells; i++) {
                    builder.cellByDirectIndex(i, buffer.get());
                }
                break;
            }
            case INT16: {
                ShortBuffer buffer = onnxTensor.getShortBuffer();
                for (long i = 0; i < cells; i++) {
                    builder.cellByDirectIndex(i, buffer.get());
                }
                break;
            }
            case INT32: {
                IntBuffer buffer = onnxTensor.getIntBuffer();
                for (long i = 0; i < cells; i++) {
                    builder.cellByDirectIndex(i, buffer.get());
                }
                break;
            }
            case INT64: {
                LongBuffer buffer = onnxTensor.getLongBuffer();
                for (long i = 0; i < cells; i++) {
                    builder.cellByDirectIndex(i, buffer.get());
                }
                break;
            }
            default:
                throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + tensorInfo.type);
        }
        return builder.build();
    }

    /**
     * Collects the size of every dimension of the given type.
     *
     * @throws IllegalArgumentException if a dimension has no size (toVespaType leaves a dimension
     *                                  unbound when its ONNX size is not positive)
     */
    private static DimensionSizes sizesFromType(TensorType type) {
        int dimensionCount = type.dimensions().size();
        DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount);
        for (int i = 0; i < dimensionCount; i++) {
            TensorType.Dimension dimension = type.dimensions().get(i);
            long size = dimension.size().orElseThrow(() ->
                    new IllegalArgumentException("Cannot convert ONNX tensor of type " + type +
                                                 ": dimension '" + dimension.name() + "' has no positive size"));
            builder.set(i, size);
        }
        return builder.build();
    }

    /** Converts each node's value info to a Vespa tensor type, keeping the node names as keys. */
    static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) {
        return infoMap.entrySet().stream()
                      .collect(Collectors.toMap(Map.Entry::getKey, e -> toVespaType(e.getValue().getInfo())));
    }

    /**
     * Converts ONNX tensor info to a Vespa tensor type with indexed dimensions named d0, d1, ...
     *
     * @param valueInfo the value info to convert; must be a {@link TensorInfo}
     * @return the corresponding type; dimensions whose ONNX size is not positive are left unbound
     * @throws IllegalArgumentException if the value info does not describe a tensor
     */
    static TensorType toVespaType(ValueInfo valueInfo) {
        TensorInfo tensorInfo = toTensorInfo(valueInfo);
        TensorType.Builder builder = new TensorType.Builder(toVespaValueType(tensorInfo.onnxType));
        long[] shape = tensorInfo.getShape();
        for (int i = 0; i < shape.length; ++i) {
            long dimSize = shape[i];
            String dimName = "d" + i; // standard naming convention
            if (dimSize > 0) {
                builder.indexed(dimName, dimSize);
            } else {
                builder.indexed(dimName); // unbound dimension for dim size -1
            }
        }
        return builder.build();
    }

    /** Maps an ONNX element type to a Vespa cell value type; anything not listed becomes DOUBLE. */
    private static TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) {
        switch (onnxType) {
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8;
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16;
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT;
            case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE;
        }
        return TensorType.Value.DOUBLE;
    }

    /** Casts the value info to tensor info, rejecting the non-tensor kinds (maps and sequences). */
    private static TensorInfo toTensorInfo(ValueInfo valueInfo) {
        if ( ! (valueInfo instanceof TensorInfo)) {
            throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
        }
        return (TensorInfo) valueInfo;
    }

}

// diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java
// new file mode 100644
// index 00000000000..e44ea96c534
// --- /dev/null
// +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java
// @@ -0,0 +1,5 @@
// Copyright Verizon Media.
Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +package ai.vespa.modelintegration.evaluator; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file |