aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-08-20 10:46:45 +0200
committerLester Solbakken <lesters@oath.com>2021-08-20 10:46:45 +0200
commitdc64fa828fc14983a829b1fbd9e66b6c1ebb30ac (patch)
tree0e9a34e013968fa271e6dae53251f9df6bbb4675 /model-integration
parent9fefa6305f53fa1e965c3578211c48607f87d64b (diff)
Excplicity close native onnx tensors used as input
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java14
1 files changed, 12 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 59ad20b7714..a306d09b3c1 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -35,19 +35,25 @@ public class OnnxEvaluator {
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
+ Map<String, OnnxTensor> onnxInputs = null;
try {
- Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
+ onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) {
return TensorConverter.toVespaTensor(result.get(0));
}
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
+ } finally {
+ if (onnxInputs != null) {
+ onnxInputs.values().forEach(OnnxTensor::close);
+ }
}
}
public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
+ Map<String, OnnxTensor> onnxInputs = null;
try {
- Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
+ onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
Map<String, Tensor> outputs = new HashMap<>();
try (OrtSession.Result result = session.run(onnxInputs)) {
for (Map.Entry<String, OnnxValue> output : result) {
@@ -57,6 +63,10 @@ public class OnnxEvaluator {
}
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
+ } finally {
+ if (onnxInputs != null) {
+ onnxInputs.values().forEach(OnnxTensor::close);
+ }
}
}