aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-02-27 17:02:23 +0100
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-02-27 18:13:08 +0100
commit5271b5d7241aa2aa2538b2072b8cae9b8f3d689a (patch)
tree12f025b12e86e5f9490b74dd2cae68283f779e67
parent6b40c6053b8542ae20a5bbe669f84f2d478fd697 (diff)
Replace `OnnxEvaluatorCache` with OnnxRuntime
Require an `OnnxRuntime` instance to create `OnnxEvaluator` instances. Cache underlying `OrtSession` instead of `OnnxEvaluator`. Move static helpers for checking Onnx runtime availability from `OnnxEvaluator` to `OnnxRuntime`.
-rw-r--r--application/pom.xml6
-rw-r--r--application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java10
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java4
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java9
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java8
-rw-r--r--model-evaluation/abi-spec.json4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java8
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java24
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java10
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java4
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java4
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java4
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java4
-rw-r--r--model-integration/pom.xml6
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java23
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/Generator.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java68
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java88
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java170
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java15
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java19
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java16
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java10
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java38
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java28
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java48
28 files changed, 380 insertions, 274 deletions
diff --git a/application/pom.xml b/application/pom.xml
index 2193f0fe2e3..236bcb6d81a 100644
--- a/application/pom.xml
+++ b/application/pom.xml
@@ -182,6 +182,12 @@
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <!-- Required for ContainerModelEvaluationTest -->
+ <groupId>com.microsoft.onnxruntime</groupId>
+ <artifactId>onnxruntime</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
index f838d7a5481..cd5fd42a81a 100644
--- a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
+++ b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.application.container;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.application.Application;
import com.yahoo.application.Networking;
import com.yahoo.application.container.handler.Request;
@@ -40,7 +40,7 @@ public class ContainerModelEvaluationTest {
@Test
void testCreateApplicationInstanceWithModelEvaluation() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
try (Application application =
Application.fromApplicationPackage(new File("src/test/app-packages/model-evaluation"),
Networking.disable)) {
@@ -54,17 +54,17 @@ public class ContainerModelEvaluationTest {
}
{
- String expected = "{\"cells\":[{\"address\":{},\"value\":2.496898}]}";
+ String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":2.496898}]}";
assertResponse("http://localhost/model-evaluation/v1/xgboost_xgboost_2_2/eval?format.tensors=long", expected, jdisc);
}
{
- String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}";
+ String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}";
assertResponse("http://localhost/model-evaluation/v1/lightgbm_regression/eval?format.tensors=long", expected, jdisc);
}
{
- String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.3671652674674988}]}";
+ String expected = "{\"type\":\"tensor<float>(d0[3])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.36716532707214355}]}";
assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc);
assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default.output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc);
assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default/output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
index 57110b2431e..3d9a8441ed5 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
@@ -44,7 +44,7 @@ public class ContainerModelEvaluation implements
public ContainerModelEvaluation(ApplicationContainerCluster cluster, RankProfileList rankProfileList) {
this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null");
cluster.addSimpleComponent(EVALUATOR_NAME, null, EVALUATION_BUNDLE_NAME);
- cluster.addSimpleComponent("ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache", null, INTEGRATION_BUNDLE_NAME);
+ cluster.addSimpleComponent("ai.vespa.modelintegration.evaluator.OnnxRuntime", null, INTEGRATION_BUNDLE_NAME);
cluster.addComponent(ContainerModelEvaluation.getHandler());
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
index 063f8f3109e..5b6c7b97875 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.ml;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.tensor.Tensor;
@@ -21,7 +21,7 @@ public class ModelsEvaluatorTest {
void testModelsEvaluator() {
// Assumption fails but test passes on Intel macs
// Assumption fails and test fails on ARM64
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
ModelsEvaluator modelsEvaluator = ModelsEvaluatorTester.create("src/test/cfg/application/stateless_eval");
assertEquals(3, modelsEvaluator.models().size());
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index caf0d22d44e..fc70a65b394 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
import ai.vespa.models.handler.ModelsEvaluationHandler;
@@ -27,7 +27,10 @@ import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
-import static org.junit.jupiter.api.Assertions.*;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
/**
@@ -60,7 +63,7 @@ public class ModelEvaluationTest {
@Test
void testMl_serving() throws IOException {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
Path appDir = Path.fromString("src/test/cfg/application/ml_serving");
Path storedAppDir = appDir.append("copy");
try {
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
index a731e9c7ccc..b0fe2c09227 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
@@ -45,7 +45,7 @@ public class StatelessOnnxEvaluationTest {
@Test
void testStatelessOnnxModelNameCollision() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
Path appDir = Path.fromString("src/test/cfg/application/onnx_name_collision");
try {
ImportedModelTester tester = new ImportedModelTester("onnx", appDir);
@@ -66,7 +66,7 @@ public class StatelessOnnxEvaluationTest {
@Test
void testStatelessOnnxModelEvaluation() throws Exception {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
Path appDir = Path.fromString("src/test/cfg/application/onnx");
Path storedAppDir = appDir.append("copy");
try {
@@ -91,7 +91,7 @@ public class StatelessOnnxEvaluationTest {
@Test
void testStatelessOnnxModelEvaluationWithGpu() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
NodeResources resources = new NodeResources(4, 16, 125, 10,
NodeResources.DiskSpeed.fast, NodeResources.StorageType.local,
NodeResources.Architecture.x86_64,
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index 9fd25ac115b..667712d0daa 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -70,7 +70,7 @@
"public"
],
"methods" : [
- "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)",
+ "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxRuntime)",
"public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
"public void <init>(ai.vespa.models.evaluation.RankProfilesConfigImporter, com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)",
"public void <init>(java.util.Map)",
@@ -88,7 +88,7 @@
"public"
],
"methods" : [
- "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)",
+ "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxRuntime)",
"public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)",
"protected final java.lang.String readExpressionFromFile(java.io.File)",
"protected com.yahoo.searchlib.rankingexpression.RankingExpression readExpressionFromFile(java.lang.String, com.yahoo.config.FileReference)",
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 74233853ae9..fd5306f9add 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.AbstractComponent;
@@ -32,8 +32,8 @@ public class ModelsEvaluator extends AbstractComponent {
RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig,
FileAcquirer fileAcquirer,
- OnnxEvaluatorCache cache) {
- this(new RankProfilesConfigImporter(fileAcquirer, cache), config, constantsConfig, expressionsConfig, onnxModelsConfig);
+ OnnxRuntime onnx) {
+ this(new RankProfilesConfigImporter(fileAcquirer, onnx), config, constantsConfig, expressionsConfig, onnxModelsConfig);
}
public ModelsEvaluator(RankProfilesConfig config,
@@ -41,7 +41,7 @@ public class ModelsEvaluator extends AbstractComponent {
RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig,
FileAcquirer fileAcquirer) {
- this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxEvaluatorCache());
+ this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxRuntime());
}
public ModelsEvaluator(RankProfilesConfigImporter importer,
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index 73c5eb36539..cf97c20e881 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -2,8 +2,8 @@
package ai.vespa.models.evaluation;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -48,12 +48,12 @@ class OnnxModel implements AutoCloseable {
final List<OutputSpec> outputSpecs = new ArrayList<>();
void addInputMapping(String onnxName, String source) {
- if (referencedEvaluator != null)
+ if (evaluator != null)
throw new IllegalStateException("input mapping must be added before load()");
inputSpecs.add(new InputSpec(onnxName, source));
}
void addOutputMapping(String onnxName, String outputAs) {
- if (referencedEvaluator != null)
+ if (evaluator != null)
throw new IllegalStateException("output mapping must be added before load()");
outputSpecs.add(new OutputSpec(onnxName, outputAs));
}
@@ -61,15 +61,15 @@ class OnnxModel implements AutoCloseable {
private final String name;
private final File modelFile;
private final OnnxEvaluatorOptions options;
- private final OnnxEvaluatorCache cache;
+ private final OnnxRuntime onnx;
- private OnnxEvaluatorCache.ReferencedEvaluator referencedEvaluator;
+ private OnnxEvaluator evaluator;
- OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxEvaluatorCache cache) {
+ OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) {
this.name = name;
this.modelFile = modelFile;
this.options = options;
- this.cache = cache;
+ this.onnx = onnx;
}
public String name() {
@@ -77,8 +77,8 @@ class OnnxModel implements AutoCloseable {
}
public void load() {
- if (referencedEvaluator == null) {
- referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options);
+ if (evaluator == null) {
+ evaluator = onnx.evaluatorOf(modelFile.getPath(), options);
fillInputTypes(evaluator().getInputs());
fillOutputTypes(evaluator().getOutputs());
}
@@ -178,11 +178,11 @@ class OnnxModel implements AutoCloseable {
}
private OnnxEvaluator evaluator() {
- if (referencedEvaluator == null) {
+ if (evaluator == null) {
throw new IllegalStateException("ONNX model has not been loaded.");
}
- return referencedEvaluator.evaluator();
+ return evaluator;
}
- @Override public void close() { referencedEvaluator.close(); }
+ @Override public void close() { evaluator.close(); }
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 6148287a536..8c520e87001 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,8 +1,8 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -47,11 +47,11 @@ import java.util.regex.Pattern;
public class RankProfilesConfigImporter {
private final FileAcquirer fileAcquirer;
- private final OnnxEvaluatorCache onnxEvaluatorCache;
+ private final OnnxRuntime onnx;
- public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxEvaluatorCache onnxEvaluatorCache) {
+ public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxRuntime onnx) {
this.fileAcquirer = fileAcquirer;
- this.onnxEvaluatorCache = onnxEvaluatorCache;
+ this.onnx = onnx;
}
/**
@@ -198,7 +198,7 @@ public class RankProfilesConfigImporter {
options.setInterOpThreads(onnxModelConfig.stateless_interop_threads());
options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads());
options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required());
- var m = new OnnxModel(name, file, options, onnxEvaluatorCache);
+ var m = new OnnxModel(name, file, options, onnx);
for (var spec : onnxModelConfig.input()) {
m.addInputMapping(spec.name(), spec.source());
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
index 992dae22aaf..0bee33be3cc 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
@@ -30,7 +30,7 @@ public class OnnxEvaluatorTest {
@Test
public void testOnnxEvaluation() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
ModelsEvaluator models = createModels();
assertTrue(models.models().containsKey("add_mul"));
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
index bfba5ae24c4..0dd3bd29a2c 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
@@ -25,7 +25,7 @@ public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesC
private final Path constantsPath;
public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) {
- super(fileAcquirer, new OnnxEvaluatorCache());
+ super(fileAcquirer, new OnnxRuntime());
this.constantsPath = constantsPath;
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 9b2b793212b..14da15f60d0 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.handler;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.models.evaluation.ModelsEvaluator;
import ai.vespa.models.evaluation.RankProfilesConfigImporterWithMockedConstants;
import com.yahoo.config.subscription.ConfigGetter;
@@ -323,7 +323,7 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSavedEvaluateSpecificFunction() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
Map<String, String> properties = new HashMap<>();
properties.put("input", inputTensor());
properties.put("format.tensors", "long");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
index 86f56e14e2d..856031da72f 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.handler;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -27,7 +27,7 @@ public class OnnxEvaluationHandlerTest {
@BeforeClass
static public void setUp() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
handler = new HandlerTester(createModels());
}
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 8f26758cf65..9bb60827a68 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -69,6 +69,12 @@
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>component</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>net.java.dev.jna</groupId>
<artifactId>jna</artifactId>
<scope>provided</scope>
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index 002350ce3cf..b0b4f871163 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -2,8 +2,10 @@ package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
-import com.yahoo.embedding.BertBaseEmbedderConfig;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.annotation.Inject;
+import com.yahoo.embedding.BertBaseEmbedderConfig;
+import com.yahoo.jdisc.AbstractResource;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.wordpiece.WordPieceEmbedder;
import com.yahoo.tensor.IndexedTensor;
@@ -28,7 +30,7 @@ import java.util.Map;
*
* @author lesters
*/
-public class BertBaseEmbedder implements Embedder {
+public class BertBaseEmbedder extends AbstractResource implements Embedder {
private final static int TOKEN_CLS = 101; // [CLS]
private final static int TOKEN_SEP = 102; // [SEP]
@@ -44,7 +46,7 @@ public class BertBaseEmbedder implements Embedder {
private final OnnxEvaluator evaluator;
@Inject
- public BertBaseEmbedder(BertBaseEmbedderConfig config) {
+ public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) {
maxTokens = config.transformerMaxTokens();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
@@ -58,7 +60,7 @@ public class BertBaseEmbedder implements Embedder {
options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads()));
tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build();
- evaluator = new OnnxEvaluator(config.transformerModel().toString(), options);
+ this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options);
validateModel();
}
@@ -100,6 +102,8 @@ public class BertBaseEmbedder implements Embedder {
return embedTokens(tokens, type);
}
+ @Override protected void destroy() { evaluator.close(); }
+
Tensor embedTokens(List<Integer> tokens, TensorType type) {
Tensor inputSequence = createTensorRepresentation(tokens, "d1");
Tensor attentionMask = createAttentionMask(inputSequence);
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 81150fe99b0..bad4bb5c9b3 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -3,22 +3,25 @@ package ai.vespa.embedding.huggingface;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.annotation.Inject;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.jdisc.AbstractResource;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-import java.io.*;
+import java.io.IOException;
import java.nio.file.Paths;
-import java.util.*;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
-import org.slf4j.LoggerFactory;
-import org.slf4j.Logger;
-
-public class HuggingFaceEmbedder implements Embedder {
+public class HuggingFaceEmbedder extends AbstractResource implements Embedder {
private static final Logger LOG = LoggerFactory.getLogger(HuggingFaceEmbedder.class.getName());
@@ -30,7 +33,7 @@ public class HuggingFaceEmbedder implements Embedder {
private final OnnxEvaluator evaluator;
@Inject
- public HuggingFaceEmbedder(HuggingFaceEmbedderConfig config) throws IOException {
+ public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException {
maxTokens = config.transformerMaxTokens();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
@@ -48,7 +51,7 @@ public class HuggingFaceEmbedder implements Embedder {
LOG.info("Could not initialize the tokenizer");
throw new IOException("Could not initialize the tokenizer.");
}
- evaluator = new OnnxEvaluator(config.transformerModel().toString());
+ evaluator = onnx.evaluatorOf(config.transformerModel().toString());
validateModel();
}
@@ -83,6 +86,8 @@ public class HuggingFaceEmbedder implements Embedder {
return tokenIds;
}
+ @Override protected void destroy() { evaluator.close(); }
+
public List<Integer> longToInteger(long[] values) {
return Arrays.stream(values)
.boxed().map(Long::intValue)
diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/Generator.java
index ed231a5e94c..a08e2006e2c 100644
--- a/model-integration/src/main/java/ai/vespa/llm/Generator.java
+++ b/model-integration/src/main/java/ai/vespa/llm/Generator.java
@@ -2,7 +2,9 @@ package ai.vespa.llm;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.annotation.Inject;
+import com.yahoo.jdisc.AbstractResource;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.sentencepiece.SentencePieceEmbedder;
import com.yahoo.llm.GeneratorConfig;
@@ -25,7 +27,7 @@ import java.util.Map;
*
* @author lesters
*/
-public class Generator {
+public class Generator extends AbstractResource {
private final static int TOKEN_EOS = 1; // end of sequence
@@ -46,7 +48,7 @@ public class Generator {
private final OnnxEvaluator decoder;
@Inject
- public Generator(GeneratorConfig config) {
+ public Generator(OnnxRuntime onnx, GeneratorConfig config) {
// Set up tokenizer
tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build();
tokenizerMaxTokens = config.tokenizerMaxTokens();
@@ -61,7 +63,7 @@ public class Generator {
encoderOptions.setInterOpThreads(modifyThreadCount(config.encoderOnnxInterOpThreads()));
encoderOptions.setIntraOpThreads(modifyThreadCount(config.encoderOnnxIntraOpThreads()));
- encoder = new OnnxEvaluator(config.encoderModel().toString(), encoderOptions);
+ encoder = onnx.evaluatorOf(config.encoderModel().toString(), encoderOptions);
// Set up decoder
decoderInputIdsName = config.decoderModelInputIdsName();
@@ -74,7 +76,7 @@ public class Generator {
decoderOptions.setInterOpThreads(modifyThreadCount(config.decoderOnnxInterOpThreads()));
decoderOptions.setIntraOpThreads(modifyThreadCount(config.decoderOnnxIntraOpThreads()));
- decoder = new OnnxEvaluator(config.decoderModel().toString(), decoderOptions);
+ decoder = onnx.evaluatorOf(config.decoderModel().toString(), decoderOptions);
validateModels();
}
@@ -99,6 +101,8 @@ public class Generator {
return generate(prompt, new GeneratorOptions());
}
+ @Override protected void destroy() { encoder.close(); decoder.close(); }
+
private String generateNotImplemented(GeneratorOptions options) {
throw new UnsupportedOperationException("Search method '" + options.getSearchMethod() + "' is currently not implemented");
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index c2d97e37074..7cdc27b6d63 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -5,9 +5,9 @@ package ai.vespa.modelintegration.evaluator;
import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
-import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime.ReferencedOrtSession;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -15,6 +15,8 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+import static ai.vespa.modelintegration.evaluator.OnnxRuntime.isCudaError;
+
/**
* Evaluates an ONNX Model by deferring to ONNX Runtime.
@@ -23,24 +25,18 @@ import java.util.Map;
*/
public class OnnxEvaluator implements AutoCloseable {
- private final OrtEnvironment environment;
- private final OrtSession session;
-
- public OnnxEvaluator(String modelPath) {
- this(modelPath, null);
- }
+ private final ReferencedOrtSession session;
- public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
- environment = OrtEnvironment.getEnvironment();
- session = createSession(modelPath, environment, options, true);
+ OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options, OnnxRuntime runtime) {
+ session = createSession(modelPath, runtime, options, true);
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
Map<String, OnnxTensor> onnxInputs = null;
try {
output = mapToInternalName(output);
- onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
- try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) {
+ onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance());
+ try (OrtSession.Result result = session.instance().run(onnxInputs, Collections.singleton(output))) {
return TensorConverter.toVespaTensor(result.get(0));
}
} catch (OrtException e) {
@@ -55,9 +51,9 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
Map<String, OnnxTensor> onnxInputs = null;
try {
- onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
+ onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance());
Map<String, Tensor> outputs = new HashMap<>();
- try (OrtSession.Result result = session.run(onnxInputs)) {
+ try (OrtSession.Result result = session.instance().run(onnxInputs)) {
for (Map.Entry<String, OnnxValue> output : result) {
String mapped = TensorConverter.asValidName(output.getKey());
outputs.put(mapped, TensorConverter.toVespaTensor(output.getValue()));
@@ -88,7 +84,7 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, IdAndType> getInputs() {
try {
- return toSpecMap(session.getInputInfo());
+ return toSpecMap(session.instance().getInputInfo());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
@@ -96,7 +92,7 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, IdAndType> getOutputs() {
try {
- return toSpecMap(session.getOutputInfo());
+ return toSpecMap(session.instance().getOutputInfo());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
@@ -104,7 +100,7 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, TensorType> getInputInfo() {
try {
- return TensorConverter.toVespaTypes(session.getInputInfo());
+ return TensorConverter.toVespaTypes(session.instance().getInputInfo());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
@@ -112,7 +108,7 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, TensorType> getOutputInfo() {
try {
- return TensorConverter.toVespaTypes(session.getOutputInfo());
+ return TensorConverter.toVespaTypes(session.instance().getOutputInfo());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
@@ -122,26 +118,26 @@ public class OnnxEvaluator implements AutoCloseable {
public void close() throws IllegalStateException {
try {
session.close();
- } catch (OrtException e) {
+ } catch (UncheckedOrtException e) {
throw new IllegalStateException("Failed to close ONNX session", e);
} catch (IllegalStateException e) {
throw new IllegalStateException("Already closed", e);
}
}
- private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) {
+ private static ReferencedOrtSession createSession(String modelPath, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) {
if (options == null) {
options = new OnnxEvaluatorOptions();
}
try {
- return environment.createSession(modelPath, options.getOptions(tryCuda && options.requestingGpu()));
+ return runtime.acquireSession(modelPath, options, tryCuda && options.requestingGpu());
} catch (OrtException e) {
if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
throw new IllegalArgumentException("No such file: " + modelPath);
}
if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
// Failed in CUDA native code, but GPU device is optional, so we can proceed without it
- return createSession(modelPath, environment, options, false);
+ return createSession(modelPath, runtime, options, false);
}
if (isCudaError(e)) {
throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e);
@@ -150,34 +146,8 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
- private static boolean isCudaError(OrtException e) {
- return switch (e.getCode()) {
- case ORT_FAIL -> e.getMessage().contains("cudaError");
- case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA");
- default -> false;
- };
- }
-
- public static boolean isRuntimeAvailable() {
- return isRuntimeAvailable("");
- }
-
- public static boolean isRuntimeAvailable(String modelPath) {
- try {
- new OnnxEvaluator(modelPath);
- return true;
- } catch (IllegalArgumentException e) {
- if (e.getMessage().equals("No such file: ")) {
- return true;
- }
- return false;
- } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) {
- return false;
- }
- }
-
private String mapToInternalName(String outputName) throws OrtException {
- var info = session.getOutputInfo();
+ var info = session.instance().getOutputInfo();
var internalNames = info.keySet();
for (String name : internalNames) {
if (name.equals(outputName)) {
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java
deleted file mode 100644
index b92ce24a6b4..00000000000
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java
+++ /dev/null
@@ -1,88 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.modelintegration.evaluator;
-
-import com.yahoo.jdisc.AbstractResource;
-import com.yahoo.jdisc.ReferencedResource;
-import com.yahoo.jdisc.ResourceReference;
-
-import javax.inject.Inject;
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * Caches instances of {@link OnnxEvaluator}.
- *
- * @author bjorncs
- */
-public class OnnxEvaluatorCache {
-
- // For mocking OnnxEvaluator in tests
- @FunctionalInterface interface OnnxEvaluatorFactory { OnnxEvaluator create(String path, OnnxEvaluatorOptions opts); }
-
- private final Object monitor = new Object();
- private final Map<Id, SharedEvaluator> cache = new HashMap<>();
- private final OnnxEvaluatorFactory factory;
-
- @Inject public OnnxEvaluatorCache() { this(OnnxEvaluator::new); }
-
- OnnxEvaluatorCache(OnnxEvaluatorFactory factory) { this.factory = factory; }
-
- public ReferencedEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
- synchronized (monitor) {
- var id = new Id(modelPath, options);
- var sharedInstance = cache.get(id);
- if (sharedInstance == null) {
- return newInstance(id);
- } else {
- ResourceReference reference;
- try {
- // refer() may throw if last reference was just released, but instance has not yet been removed from cache
- reference = sharedInstance.refer(id);
- } catch (IllegalStateException e) {
- return newInstance(id);
- }
- return new ReferencedEvaluator(sharedInstance, reference);
- }
- }
- }
-
- int size() { return cache.size(); }
-
- private ReferencedEvaluator newInstance(Id id) {
- var evaluator = new SharedEvaluator(id, factory.create(id.modelPath, id.options));
- cache.put(id, evaluator);
- var referenced = new ReferencedEvaluator(evaluator, evaluator.refer(id));
- // Release "main" reference to ensure that evaluator is destroyed when last external reference is released
- evaluator.release();
- return referenced;
- }
-
- // We assume options are never modified after being passed to cache
- record Id(String modelPath, OnnxEvaluatorOptions options) {}
-
- public class ReferencedEvaluator extends ReferencedResource<SharedEvaluator> {
- ReferencedEvaluator(SharedEvaluator resource, ResourceReference reference) { super(resource, reference); }
-
- public OnnxEvaluator evaluator() { return getResource().instance(); }
- }
-
- public class SharedEvaluator extends AbstractResource {
- private final Id id;
- private final OnnxEvaluator instance;
-
- private SharedEvaluator(Id id, OnnxEvaluator instance) {
- this.id = id;
- this.instance = instance;
- }
-
- public OnnxEvaluator instance() { return instance; }
-
- @Override
- protected void destroy() {
- synchronized (OnnxEvaluatorCache.this) { cache.remove(id); }
- instance.close();
- }
- }
-
-}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
new file mode 100644
index 00000000000..42830041c02
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
@@ -0,0 +1,170 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import com.yahoo.component.AbstractComponent;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.jdisc.ResourceReference;
+import com.yahoo.jdisc.refcount.DebugReferencesWithStack;
+import com.yahoo.jdisc.refcount.References;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import static com.yahoo.yolean.Exceptions.throwUnchecked;
+
+/**
+ * Provides ONNX runtime environment with session management.
+ *
+ * @author bjorncs
+ */
+public class OnnxRuntime extends AbstractComponent {
+
+ // For unit testing
+ @FunctionalInterface interface OrtSessionFactory {
+ OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException;
+ }
+
+ private static final Logger log = Logger.getLogger(OnnxRuntime.class.getName());
+
+ private static final OrtEnvironmentResult ortEnvironment = getOrtEnvironment();
+ private static final OrtSessionFactory defaultFactory = (path, opts) -> ortEnvironment().createSession(path, opts);
+
+ private final Object monitor = new Object();
+ private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>();
+ private final OrtSessionFactory factory;
+
+ @Inject public OnnxRuntime() { this(defaultFactory); }
+
+ OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; }
+
+ public OnnxEvaluator evaluatorOf(String modelPath) {
+ return new OnnxEvaluator(modelPath, null, this);
+ }
+
+ public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
+ return new OnnxEvaluator(modelPath, options, this);
+ }
+
+ public static OrtEnvironment ortEnvironment() {
+ if (ortEnvironment.env() != null) return ortEnvironment.env();
+ throw throwUnchecked(ortEnvironment.failure());
+ }
+
+ @Override
+ public void deconstruct() {
+ synchronized (monitor) {
+ sessions.forEach((id, sharedSession) -> {
+ int hash = System.identityHashCode(sharedSession.session());
+ var refs = sharedSession.references();
+ log.warning("Closing leaked session %s (%s) with %d outstanding references:\n%s"
+ .formatted(id, hash, refs.referenceCount(), refs.currentState()));
+ try {
+ sharedSession.session().close();
+ } catch (Exception e) {
+ log.log(Level.WARNING, "Failed to close session %s (%s)".formatted(id, hash), e);
+ }
+ });
+ sessions.clear();
+ }
+ }
+
+ private static OrtEnvironmentResult getOrtEnvironment() {
+ try {
+ return new OrtEnvironmentResult(OrtEnvironment.getEnvironment(), null);
+ } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) {
+ log.log(Level.FINE, e, () -> "Failed to load ONNX runtime");
+ return new OrtEnvironmentResult(null, e);
+ }
+ }
+
+ public static boolean isRuntimeAvailable() { return ortEnvironment.env() != null; }
+ public static boolean isRuntimeAvailable(String modelPath) {
+ if (!isRuntimeAvailable()) return false;
+ try {
+ // Expensive way of checking if runtime is available as it incurs the cost of loading the model if successful
+ defaultFactory.create(modelPath, new OnnxEvaluatorOptions().getOptions(false));
+ return true;
+ } catch (OrtException e) {
+ return e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE;
+ } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) {
+ return false;
+ }
+ }
+
+ static boolean isCudaError(OrtException e) {
+ return switch (e.getCode()) {
+ case ORT_FAIL -> e.getMessage().contains("cudaError");
+ case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA");
+ default -> false;
+ };
+ }
+
+ ReferencedOrtSession acquireSession(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException {
+ var sessionId = new OrtSessionId(modelPath, options, loadCuda);
+ synchronized (monitor) {
+ var sharedSession = sessions.get(sessionId);
+ if (sharedSession != null) {
+ return sharedSession.newReference();
+ }
+ }
+
+ // Note: identical models loaded simultaneously will result in duplicate session instances
+ var session = factory.create(modelPath, options.getOptions(loadCuda));
+ log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(session)));
+
+ var sharedSession = new SharedOrtSession(sessionId, session);
+ var referencedSession = sharedSession.newReference();
+ synchronized (monitor) { sessions.put(sessionId, sharedSession); }
+ sharedSession.references().release(); // Release initial reference
+ return referencedSession;
+ }
+
+ int sessionsCached() { synchronized(monitor) { return sessions.size(); } }
+
+ public static class ReferencedOrtSession implements AutoCloseable {
+ private final OrtSession instance;
+ private final ResourceReference ref;
+
+ public ReferencedOrtSession(OrtSession instance, ResourceReference ref) {
+ this.instance = instance;
+ this.ref = ref;
+ }
+
+ public OrtSession instance() { return instance; }
+ @Override public void close() { ref.close(); }
+ }
+
+ // Assumes options are never modified after being stored in `onnxSessions`
+ record OrtSessionId(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) {}
+
+ record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {}
+
+ private class SharedOrtSession {
+ private final OrtSessionId id;
+ private final OrtSession session;
+ private final References refs = new DebugReferencesWithStack(this::close);
+
+ SharedOrtSession(OrtSessionId id, OrtSession session) {
+ this.id = id;
+ this.session = session;
+ }
+
+ ReferencedOrtSession newReference() { return new ReferencedOrtSession(session, refs.refer(id)); }
+ References references() { return refs; }
+ OrtSession session() { return session; }
+
+ void close() {
+ try {
+ synchronized (OnnxRuntime.this.monitor) { sessions.remove(id); }
+ log.fine(() -> "Closing session (%s)".formatted(System.identityHashCode(session)));
+ session.close();
+ } catch (OrtException e) { throw new UncheckedOrtException(e);}
+ }
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java
new file mode 100644
index 00000000000..1f2c8ba2cf7
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java
@@ -0,0 +1,15 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OrtException;
+
+/**
+ * @author bjorncs
+ */
+public class UncheckedOrtException extends RuntimeException {
+
+ public UncheckedOrtException(Throwable e) { super(e.getMessage(), e); }
+
+ @Override public synchronized OrtException getCause() { return (OrtException) super.getCause(); }
+}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index b06a54d68bb..329b87cacd1 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -1,13 +1,12 @@
package ai.vespa.embedding;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
-import java.lang.IllegalArgumentException;
import java.util.List;
import static org.junit.Assert.assertEquals;
@@ -20,12 +19,12 @@ public class BertBaseEmbedderTest {
public void testEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
String modelPath = "src/test/models/onnx/transformer/dummy_transformer.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
- BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
+ BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");
List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer
@@ -39,13 +38,13 @@ public class BertBaseEmbedderTest {
public void testEmbedderWithoutTokenTypeIdsName() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
builder.transformerTokenTypeIds("");
- BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
+ BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");
List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer
@@ -59,14 +58,18 @@ public class BertBaseEmbedderTest {
public void testEmbedderWithoutTokenTypeIdsNameButWithConfig() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
// we did not configured BertBaseEmbedder to accept missing token type ids
// so we expect ctor to throw
- assertThrows(IllegalArgumentException.class, () -> { new BertBaseEmbedder(builder.build()); });
+ assertThrows(IllegalArgumentException.class, () -> { newBertBaseEmbedder(builder.build()); });
+ }
+
+ private static BertBaseEmbedder newBertBaseEmbedder(BertBaseEmbedderConfig cfg) {
+ return new BertBaseEmbedder(new OnnxRuntime(), cfg);
}
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
index c67b6b0dcab..0ff9acc9a69 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
@@ -1,19 +1,5 @@
package ai.vespa.embedding.huggingface;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
-import com.yahoo.config.ModelReference;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-
-import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
-
-import java.io.IOException;
-import java.util.List;
-
-import static org.junit.Assume.assumeTrue;
-import static org.junit.Assert.assertEquals;
-
public class HuggingFaceEmbedderTest {
/*
@Test
@@ -21,7 +7,7 @@ public class HuggingFaceEmbedderTest {
String modelPath = "src/test/models/hf/model.onnx";
String tokenizerPath = "src/test/models/hf/tokenizer.json";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
builder.tokenizerPath(ModelReference.valueOf(tokenizerPath));
diff --git a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
index 733430aa10d..c22902b344f 100644
--- a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
@@ -1,6 +1,6 @@
package ai.vespa.llm;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.llm.GeneratorConfig;
import org.junit.Test;
@@ -15,13 +15,13 @@ public class GeneratorTest {
String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";
String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx";
String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(encoderModelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(encoderModelPath));
GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
builder.tokenizerModel(ModelReference.valueOf(vocabPath));
builder.encoderModel(ModelReference.valueOf(encoderModelPath));
builder.decoderModel(ModelReference.valueOf(decoderModelPath));
- Generator generator = new Generator(builder.build());
+ Generator generator = newGenerator(builder.build());
GeneratorOptions options = new GeneratorOptions();
options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
@@ -33,4 +33,8 @@ public class GeneratorTest {
assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result);
}
+ private static Generator newGenerator(GeneratorConfig cfg) {
+ return new Generator(new OnnxRuntime(), cfg);
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
deleted file mode 100644
index acce660f466..00000000000
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.modelintegration.evaluator;
-
-import org.junit.jupiter.api.Test;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotSame;
-import static org.junit.jupiter.api.Assertions.assertSame;
-import static org.mockito.Mockito.mock;
-
-/**
- * @author bjorncs
- */
-class OnnxEvaluatorCacheTest {
-
- @Test
- void reuses_instance_while_in_use() {
- var cache = new OnnxEvaluatorCache((__, ___) -> mock(OnnxEvaluator.class));
- var referencedEvaluator1 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
- var referencedEvaluator2 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
- var referencedEvaluator3 = cache.evaluatorOf("model2", new OnnxEvaluatorOptions());
- assertSame(referencedEvaluator1.evaluator(), referencedEvaluator2.evaluator());
- assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator3.evaluator());
- assertEquals(2, cache.size());
- referencedEvaluator1.close();
- referencedEvaluator2.close();
- assertEquals(1, cache.size());
- referencedEvaluator3.close();
- assertEquals(0, cache.size());
- var referencedEvaluator4 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
- assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator4.evaluator());
- assertEquals(1, cache.size());
- referencedEvaluator4.close();
- assertEquals(0, cache.size());
- }
-
-} \ No newline at end of file
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 83f355821e5..5aba54de11b 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -5,23 +5,31 @@ package ai.vespa.modelintegration.evaluator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
+import org.junit.jupiter.api.BeforeAll;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assume.assumeTrue;
+import static org.junit.Assume.assumeNotNull;
/**
* @author lesters
*/
public class OnnxEvaluatorTest {
+ private static OnnxRuntime runtime;
+
+ @BeforeAll
+ public static void beforeAll() {
+ if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime();
+ }
+
@Test
public void testSimpleModel() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx");
+ assumeNotNull(runtime);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx");
// Input types
Map<String, TensorType> inputTypes = evaluator.getInputInfo();
@@ -45,8 +53,8 @@ public class OnnxEvaluatorTest {
@Test
public void testBatchDimension() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx");
+ assumeNotNull(runtime);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx");
// Input types
Map<String, TensorType> inputTypes = evaluator.getInputInfo();
@@ -64,7 +72,7 @@ public class OnnxEvaluatorTest {
@Test
public void testMatMul() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeNotNull(runtime);
String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]";
String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]";
String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]";
@@ -73,7 +81,7 @@ public class OnnxEvaluatorTest {
@Test
public void testTypes() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeNotNull(runtime);
assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
@@ -86,8 +94,8 @@ public class OnnxEvaluatorTest {
@Test
public void testNotIdentifiers() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/badnames.onnx");
+ assumeNotNull(runtime);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx");
var inputInfo = evaluator.getInputInfo();
var outputInfo = evaluator.getOutputInfo();
for (var entry : inputInfo.entrySet()) {
@@ -152,7 +160,7 @@ public class OnnxEvaluatorTest {
}
private void assertEvaluate(String model, String output, String... input) {
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/" + model);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model);
Map<String, Tensor> inputs = new HashMap<>();
for (int i = 0; i < input.length; ++i) {
inputs.put("input" + (i+1), Tensor.from(input[i]));
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
new file mode 100644
index 00000000000..81b1237e770
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
@@ -0,0 +1,48 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+
+/**
+ * @author bjorncs
+ */
+class OnnxRuntimeTest {
+
+ @Test
+ void reuses_sessions_while_active() throws OrtException {
+ var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class));
+ var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
+ var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
+ var session3 = runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false);
+ assertSame(session1.instance(), session2.instance());
+ assertNotSame(session1.instance(), session3.instance());
+ assertEquals(2, runtime.sessionsCached());
+
+ session1.close();
+ session2.close();
+ assertEquals(1, runtime.sessionsCached());
+ verify(session1.instance()).close();
+ verify(session3.instance(), never()).close();
+
+ session3.close();
+ assertEquals(0, runtime.sessionsCached());
+ verify(session3.instance()).close();
+
+ var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
+ assertNotSame(session1.instance(), session4.instance());
+ assertEquals(1, runtime.sessionsCached());
+ session4.close();
+ assertEquals(0, runtime.sessionsCached());
+ verify(session4.instance()).close();
+ }
+} \ No newline at end of file