diff options
author | Lester Solbakken <lesters@oath.com> | 2021-08-20 10:46:45 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-08-20 10:46:45 +0200 |
commit | dc64fa828fc14983a829b1fbd9e66b6c1ebb30ac (patch) | |
tree | 0e9a34e013968fa271e6dae53251f9df6bbb4675 /model-integration | |
parent | 9fefa6305f53fa1e965c3578211c48607f87d64b (diff) |
Excplicity close native onnx tensors used as input
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 59ad20b7714..a306d09b3c1 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -35,19 +35,25 @@ public class OnnxEvaluator { } public Tensor evaluate(Map<String, Tensor> inputs, String output) { + Map<String, OnnxTensor> onnxInputs = null; try { - Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); + onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) { return TensorConverter.toVespaTensor(result.get(0)); } } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); + } finally { + if (onnxInputs != null) { + onnxInputs.values().forEach(OnnxTensor::close); + } } } public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) { + Map<String, OnnxTensor> onnxInputs = null; try { - Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); + onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); Map<String, Tensor> outputs = new HashMap<>(); try (OrtSession.Result result = session.run(onnxInputs)) { for (Map.Entry<String, OnnxValue> output : result) { @@ -57,6 +63,10 @@ public class OnnxEvaluator { } } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); + } finally { + if (onnxInputs != null) { + onnxInputs.values().forEach(OnnxTensor::close); + } } } |