diff options
Diffstat (limited to 'model-integration/src/main')
3 files changed, 0 insertions, 265 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java deleted file mode 100644 index 59ad20b7714..00000000000 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.modelintegration.evaluator; - -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - - -/** - * Evaluates an ONNX Model by deferring to ONNX Runtime. - * - * @author lesters - */ -public class OnnxEvaluator { - - private final OrtEnvironment environment; - private final OrtSession session; - - public OnnxEvaluator(String modelPath) { - try { - environment = OrtEnvironment.getEnvironment(); - session = environment.createSession(modelPath, new OrtSession.SessionOptions()); - } catch (OrtException e) { - throw new RuntimeException("ONNX Runtime exception", e); - } - } - - public Tensor evaluate(Map<String, Tensor> inputs, String output) { - try { - Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); - try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) { - return TensorConverter.toVespaTensor(result.get(0)); - } - } catch (OrtException e) { - throw new RuntimeException("ONNX Runtime exception", e); - } - } - - public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) { - try { - Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); - Map<String, Tensor> outputs = new HashMap<>(); - try (OrtSession.Result result = session.run(onnxInputs)) { - for (Map.Entry<String, OnnxValue> output : result) { - outputs.put(output.getKey(), TensorConverter.toVespaTensor(output.getValue())); - } - return outputs; - } - } catch (OrtException e) { - throw new RuntimeException("ONNX Runtime exception", e); - } - } - - public Map<String, TensorType> getInputInfo() { - try { - return TensorConverter.toVespaTypes(session.getInputInfo()); - } catch (OrtException e) { - throw new RuntimeException("ONNX Runtime exception", e); - } - } - - public Map<String, TensorType> getOutputInfo() { - try { - return TensorConverter.toVespaTypes(session.getOutputInfo()); - } catch (OrtException e) { - throw new RuntimeException("ONNX Runtime exception", e); - } - } - -} diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java deleted file mode 100644 index c1f973300d6..00000000000 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.modelintegration.evaluator; - -import ai.onnxruntime.NodeInfo; -import ai.onnxruntime.OnnxJavaType; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession; -import ai.onnxruntime.TensorInfo; -import ai.onnxruntime.ValueInfo; -import com.yahoo.tensor.DimensionSizes; -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; -import java.nio.ShortBuffer; -import java.util.HashMap; -import java.util.Map; -import java.util.stream.Collectors; - - -/** - * @author lesters - */ -class TensorConverter { - - static Map<String, OnnxTensor> toOnnxTensors(Map<String, Tensor> tensorMap, OrtEnvironment env, OrtSession session) - throws OrtException - { - Map<String, OnnxTensor> result = new HashMap<>(); - for (String name : tensorMap.keySet()) { - Tensor vespaTensor = tensorMap.get(name); - TensorInfo onnxTensorInfo = toTensorInfo(session.getInputInfo().get(name).getInfo()); - OnnxTensor onnxTensor = toOnnxTensor(vespaTensor, onnxTensorInfo, env); - result.put(name, onnxTensor); - } - return result; - } - - static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment) - throws OrtException - { - if ( ! (vespaTensor instanceof IndexedTensor)) { - throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions"); - } - IndexedTensor tensor = (IndexedTensor) vespaTensor; - - ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder()); - if (onnxTensorInfo.type == OnnxJavaType.FLOAT) { - for (int i = 0; i < tensor.size(); i++) - buffer.putFloat(tensor.getFloat(i)); - return OnnxTensor.createTensor(environment, buffer.rewind().asFloatBuffer(), tensor.shape()); - } - if (onnxTensorInfo.type == OnnxJavaType.DOUBLE) { - for (int i = 0; i < tensor.size(); i++) - buffer.putDouble(tensor.get(i)); - return OnnxTensor.createTensor(environment, buffer.rewind().asDoubleBuffer(), tensor.shape()); - } - if (onnxTensorInfo.type == OnnxJavaType.INT8) { - for (int i = 0; i < tensor.size(); i++) - buffer.put((byte) tensor.get(i)); - return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape()); - } - if (onnxTensorInfo.type == OnnxJavaType.INT16) { - for (int i = 0; i < tensor.size(); i++) - buffer.putShort((short) tensor.get(i)); - return OnnxTensor.createTensor(environment, buffer.rewind().asShortBuffer(), tensor.shape()); - } - if (onnxTensorInfo.type == OnnxJavaType.INT32) { - for (int i = 0; i < tensor.size(); i++) - buffer.putInt((int) tensor.get(i)); - return OnnxTensor.createTensor(environment, buffer.rewind().asIntBuffer(), tensor.shape()); - } - if (onnxTensorInfo.type == OnnxJavaType.INT64) { - for (int i = 0; i < tensor.size(); i++) - buffer.putLong((long) tensor.get(i)); - return OnnxTensor.createTensor(environment, buffer.rewind().asLongBuffer(), tensor.shape()); - } - throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type); - } - - static Tensor toVespaTensor(OnnxValue onnxValue) { - if ( ! (onnxValue instanceof OnnxTensor)) { - throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported"); - } - OnnxTensor onnxTensor = (OnnxTensor) onnxValue; - TensorInfo tensorInfo = onnxTensor.getInfo(); - - TensorType type = toVespaType(onnxTensor.getInfo()); - DimensionSizes sizes = sizesFromType(type); - - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes); - if (tensorInfo.type == OnnxJavaType.FLOAT) { - FloatBuffer buffer = onnxTensor.getFloatBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.DOUBLE) { - DoubleBuffer buffer = onnxTensor.getDoubleBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT8) { - ByteBuffer buffer = onnxTensor.getByteBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT32) { - IntBuffer buffer = onnxTensor.getIntBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT64) { - LongBuffer buffer = onnxTensor.getLongBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else { - throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); - } - return builder.build(); - } - - static private DimensionSizes sizesFromType(TensorType type) { - DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); - for (int i = 0; i < type.dimensions().size(); i++) - builder.set(i, type.dimensions().get(i).size().get()); - return builder.build(); - } - - static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) { - return infoMap.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> toVespaType(e.getValue().getInfo()))); - } - - static TensorType toVespaType(ValueInfo valueInfo) { - TensorInfo tensorInfo = toTensorInfo(valueInfo); - TensorType.Builder builder = new TensorType.Builder(toVespaValueType(tensorInfo.onnxType)); - long[] shape = tensorInfo.getShape(); - for (int i = 0; i < shape.length; ++i) { - long dimSize = shape[i]; - String dimName = "d" + i; // standard naming convention - if (dimSize > 0) - builder.indexed(dimName, dimSize); - else - builder.indexed(dimName); // unbound dimension for dim size -1 - } - return builder.build(); - } - - static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) { - switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE; - } - return TensorType.Value.DOUBLE; - } - - static private TensorInfo toTensorInfo(ValueInfo valueInfo) { - if ( ! (valueInfo instanceof TensorInfo)) { - throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported"); - } - return (TensorInfo) valueInfo; - } - -} diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java deleted file mode 100644 index e44ea96c534..00000000000 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -package ai.vespa.modelintegration.evaluator; - -import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file |