summary | refs | log | tree | commit | diff | stats
path: root/model-evaluation
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-02-27 17:02:23 +0100
committer Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-02-27 18:13:08 +0100
commit 5271b5d7241aa2aa2538b2072b8cae9b8f3d689a (patch)
tree 12f025b12e86e5f9490b74dd2cae68283f779e67 /model-evaluation
parent 6b40c6053b8542ae20a5bbe669f84f2d478fd697 (diff)
Replace `OnnxEvaluatorCache` with `OnnxRuntime`
Require an `OnnxRuntime` instance to create `OnnxEvaluator` instances. Cache underlying `OrtSession` instead of `OnnxEvaluator`. Move static helpers for checking Onnx runtime availability from `OnnxEvaluator` to `OnnxRuntime`.
Diffstat (limited to 'model-evaluation')
-rw-r--r-- model-evaluation/abi-spec.json 4
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java 8
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java 24
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java 10
-rw-r--r-- model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java 4
-rw-r--r-- model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java 4
-rw-r--r-- model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java 4
-rw-r--r-- model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java 4
8 files changed, 31 insertions, 31 deletions
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index 9fd25ac115b..667712d0daa 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -70,7 +70,7 @@
"public"
],
"methods" : [
- "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)",
+ "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxRuntime)",
"public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
"public void <init>(ai.vespa.models.evaluation.RankProfilesConfigImporter, com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)",
"public void <init>(java.util.Map)",
@@ -88,7 +88,7 @@
"public"
],
"methods" : [
- "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)",
+ "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxRuntime)",
"public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)",
"protected final java.lang.String readExpressionFromFile(java.io.File)",
"protected com.yahoo.searchlib.rankingexpression.RankingExpression readExpressionFromFile(java.lang.String, com.yahoo.config.FileReference)",
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 74233853ae9..fd5306f9add 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.AbstractComponent;
@@ -32,8 +32,8 @@ public class ModelsEvaluator extends AbstractComponent {
RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig,
FileAcquirer fileAcquirer,
- OnnxEvaluatorCache cache) {
- this(new RankProfilesConfigImporter(fileAcquirer, cache), config, constantsConfig, expressionsConfig, onnxModelsConfig);
+ OnnxRuntime onnx) {
+ this(new RankProfilesConfigImporter(fileAcquirer, onnx), config, constantsConfig, expressionsConfig, onnxModelsConfig);
}
public ModelsEvaluator(RankProfilesConfig config,
@@ -41,7 +41,7 @@ public class ModelsEvaluator extends AbstractComponent {
RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig,
FileAcquirer fileAcquirer) {
- this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxEvaluatorCache());
+ this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxRuntime());
}
public ModelsEvaluator(RankProfilesConfigImporter importer,
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index 73c5eb36539..cf97c20e881 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -2,8 +2,8 @@
package ai.vespa.models.evaluation;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -48,12 +48,12 @@ class OnnxModel implements AutoCloseable {
final List<OutputSpec> outputSpecs = new ArrayList<>();
void addInputMapping(String onnxName, String source) {
- if (referencedEvaluator != null)
+ if (evaluator != null)
throw new IllegalStateException("input mapping must be added before load()");
inputSpecs.add(new InputSpec(onnxName, source));
}
void addOutputMapping(String onnxName, String outputAs) {
- if (referencedEvaluator != null)
+ if (evaluator != null)
throw new IllegalStateException("output mapping must be added before load()");
outputSpecs.add(new OutputSpec(onnxName, outputAs));
}
@@ -61,15 +61,15 @@ class OnnxModel implements AutoCloseable {
private final String name;
private final File modelFile;
private final OnnxEvaluatorOptions options;
- private final OnnxEvaluatorCache cache;
+ private final OnnxRuntime onnx;
- private OnnxEvaluatorCache.ReferencedEvaluator referencedEvaluator;
+ private OnnxEvaluator evaluator;
- OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxEvaluatorCache cache) {
+ OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) {
this.name = name;
this.modelFile = modelFile;
this.options = options;
- this.cache = cache;
+ this.onnx = onnx;
}
public String name() {
@@ -77,8 +77,8 @@ class OnnxModel implements AutoCloseable {
}
public void load() {
- if (referencedEvaluator == null) {
- referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options);
+ if (evaluator == null) {
+ evaluator = onnx.evaluatorOf(modelFile.getPath(), options);
fillInputTypes(evaluator().getInputs());
fillOutputTypes(evaluator().getOutputs());
}
@@ -178,11 +178,11 @@ class OnnxModel implements AutoCloseable {
}
private OnnxEvaluator evaluator() {
- if (referencedEvaluator == null) {
+ if (evaluator == null) {
throw new IllegalStateException("ONNX model has not been loaded.");
}
- return referencedEvaluator.evaluator();
+ return evaluator;
}
- @Override public void close() { referencedEvaluator.close(); }
+ @Override public void close() { evaluator.close(); }
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 6148287a536..8c520e87001 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,8 +1,8 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -47,11 +47,11 @@ import java.util.regex.Pattern;
public class RankProfilesConfigImporter {
private final FileAcquirer fileAcquirer;
- private final OnnxEvaluatorCache onnxEvaluatorCache;
+ private final OnnxRuntime onnx;
- public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxEvaluatorCache onnxEvaluatorCache) {
+ public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxRuntime onnx) {
this.fileAcquirer = fileAcquirer;
- this.onnxEvaluatorCache = onnxEvaluatorCache;
+ this.onnx = onnx;
}
/**
@@ -198,7 +198,7 @@ public class RankProfilesConfigImporter {
options.setInterOpThreads(onnxModelConfig.stateless_interop_threads());
options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads());
options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required());
- var m = new OnnxModel(name, file, options, onnxEvaluatorCache);
+ var m = new OnnxModel(name, file, options, onnx);
for (var spec : onnxModelConfig.input()) {
m.addInputMapping(spec.name(), spec.source());
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
index 992dae22aaf..0bee33be3cc 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
@@ -30,7 +30,7 @@ public class OnnxEvaluatorTest {
@Test
public void testOnnxEvaluation() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
ModelsEvaluator models = createModels();
assertTrue(models.models().containsKey("add_mul"));
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
index bfba5ae24c4..0dd3bd29a2c 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
@@ -25,7 +25,7 @@ public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesC
private final Path constantsPath;
public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) {
- super(fileAcquirer, new OnnxEvaluatorCache());
+ super(fileAcquirer, new OnnxRuntime());
this.constantsPath = constantsPath;
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 9b2b793212b..14da15f60d0 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.handler;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.models.evaluation.ModelsEvaluator;
import ai.vespa.models.evaluation.RankProfilesConfigImporterWithMockedConstants;
import com.yahoo.config.subscription.ConfigGetter;
@@ -323,7 +323,7 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSavedEvaluateSpecificFunction() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
Map<String, String> properties = new HashMap<>();
properties.put("input", inputTensor());
properties.put("format.tensors", "long");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
index 86f56e14e2d..856031da72f 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.handler;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -27,7 +27,7 @@ public class OnnxEvaluationHandlerTest {
@BeforeClass
static public void setUp() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
handler = new HandlerTester(createModels());
}