aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java79
1 files changed, 0 insertions, 79 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
deleted file mode 100644
index 59ad20b7714..00000000000
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.modelintegration.evaluator;
-
-import ai.onnxruntime.OnnxTensor;
-import ai.onnxruntime.OnnxValue;
-import ai.onnxruntime.OrtEnvironment;
-import ai.onnxruntime.OrtException;
-import ai.onnxruntime.OrtSession;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-
-
-/**
- * Evaluates an ONNX Model by deferring to ONNX Runtime.
- *
- * @author lesters
- */
-public class OnnxEvaluator {
-
- private final OrtEnvironment environment;
- private final OrtSession session;
-
- public OnnxEvaluator(String modelPath) {
- try {
- environment = OrtEnvironment.getEnvironment();
- session = environment.createSession(modelPath, new OrtSession.SessionOptions());
- } catch (OrtException e) {
- throw new RuntimeException("ONNX Runtime exception", e);
- }
- }
-
- public Tensor evaluate(Map<String, Tensor> inputs, String output) {
- try {
- Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
- try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) {
- return TensorConverter.toVespaTensor(result.get(0));
- }
- } catch (OrtException e) {
- throw new RuntimeException("ONNX Runtime exception", e);
- }
- }
-
- public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
- try {
- Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
- Map<String, Tensor> outputs = new HashMap<>();
- try (OrtSession.Result result = session.run(onnxInputs)) {
- for (Map.Entry<String, OnnxValue> output : result) {
- outputs.put(output.getKey(), TensorConverter.toVespaTensor(output.getValue()));
- }
- return outputs;
- }
- } catch (OrtException e) {
- throw new RuntimeException("ONNX Runtime exception", e);
- }
- }
-
- public Map<String, TensorType> getInputInfo() {
- try {
- return TensorConverter.toVespaTypes(session.getInputInfo());
- } catch (OrtException e) {
- throw new RuntimeException("ONNX Runtime exception", e);
- }
- }
-
- public Map<String, TensorType> getOutputInfo() {
- try {
- return TensorConverter.toVespaTypes(session.getOutputInfo());
- } catch (OrtException e) {
- throw new RuntimeException("ONNX Runtime exception", e);
- }
- }
-
-}