diff options
Diffstat (limited to 'model-evaluation')
8 files changed, 31 insertions, 31 deletions
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index 9fd25ac115b..667712d0daa 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -70,7 +70,7 @@ "public" ], "methods" : [ - "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)", + "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxRuntime)", "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)", "public void <init>(ai.vespa.models.evaluation.RankProfilesConfigImporter, com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)", "public void <init>(java.util.Map)", @@ -88,7 +88,7 @@ "public" ], "methods" : [ - "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)", + "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxRuntime)", "public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)", "protected final java.lang.String readExpressionFromFile(java.io.File)", "protected com.yahoo.searchlib.rankingexpression.RankingExpression readExpressionFromFile(java.lang.String, com.yahoo.config.FileReference)", diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 74233853ae9..fd5306f9add 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; import com.yahoo.component.AbstractComponent; @@ -32,8 +32,8 @@ public class ModelsEvaluator extends AbstractComponent { RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer, - OnnxEvaluatorCache cache) { - this(new RankProfilesConfigImporter(fileAcquirer, cache), config, constantsConfig, expressionsConfig, onnxModelsConfig); + OnnxRuntime onnx) { + this(new RankProfilesConfigImporter(fileAcquirer, onnx), config, constantsConfig, expressionsConfig, onnxModelsConfig); } public ModelsEvaluator(RankProfilesConfig config, @@ -41,7 +41,7 @@ public class ModelsEvaluator extends AbstractComponent { RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer) { - this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxEvaluatorCache()); + this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxRuntime()); } public ModelsEvaluator(RankProfilesConfigImporter importer, diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index 73c5eb36539..cf97c20e881 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -2,8 +2,8 @@ package ai.vespa.models.evaluation; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; -import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -48,12 +48,12 @@ class OnnxModel implements AutoCloseable { final List<OutputSpec> outputSpecs = new ArrayList<>(); void addInputMapping(String onnxName, String source) { - if (referencedEvaluator != null) + if (evaluator != null) throw new IllegalStateException("input mapping must be added before load()"); inputSpecs.add(new InputSpec(onnxName, source)); } void addOutputMapping(String onnxName, String outputAs) { - if (referencedEvaluator != null) + if (evaluator != null) throw new IllegalStateException("output mapping must be added before load()"); outputSpecs.add(new OutputSpec(onnxName, outputAs)); } @@ -61,15 +61,15 @@ class OnnxModel implements AutoCloseable { private final String name; private final File modelFile; private final OnnxEvaluatorOptions options; - private final OnnxEvaluatorCache cache; + private final OnnxRuntime onnx; - private OnnxEvaluatorCache.ReferencedEvaluator referencedEvaluator; + private OnnxEvaluator evaluator; - OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxEvaluatorCache cache) { + OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) { this.name = name; this.modelFile = modelFile; this.options = options; - this.cache = cache; + this.onnx = onnx; } public String name() { @@ -77,8 +77,8 @@ class OnnxModel implements AutoCloseable { } public void load() { - if (referencedEvaluator == null) { - referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options); + if (evaluator == null) { + evaluator = onnx.evaluatorOf(modelFile.getPath(), options); fillInputTypes(evaluator().getInputs()); fillOutputTypes(evaluator().getOutputs()); } @@ -178,11 +178,11 @@ class OnnxModel implements AutoCloseable { } private OnnxEvaluator evaluator() { - if (referencedEvaluator == null) { + if (evaluator == null) { throw new IllegalStateException("ONNX model has not been loaded."); } - return referencedEvaluator.evaluator(); + return evaluator; } - @Override public void close() { referencedEvaluator.close(); } + @Override public void close() { evaluator.close(); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 6148287a536..8c520e87001 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,8 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.collections.Pair; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -47,11 +47,11 @@ import java.util.regex.Pattern; public class RankProfilesConfigImporter { private final FileAcquirer fileAcquirer; - private final OnnxEvaluatorCache onnxEvaluatorCache; + private final OnnxRuntime onnx; - public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxEvaluatorCache onnxEvaluatorCache) { + public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxRuntime onnx) { this.fileAcquirer = fileAcquirer; - this.onnxEvaluatorCache = onnxEvaluatorCache; + this.onnx = onnx; } /** @@ -198,7 +198,7 @@ public class RankProfilesConfigImporter { options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required()); - var m = new OnnxModel(name, file, options, onnxEvaluatorCache); + var m = new OnnxModel(name, file, options, onnx); for (var spec : onnxModelConfig.input()) { m.addInputMapping(spec.name(), spec.source()); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java index 992dae22aaf..0bee33be3cc 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; @@ -30,7 +30,7 @@ public class OnnxEvaluatorTest { @Test public void testOnnxEvaluation() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); ModelsEvaluator models = createModels(); assertTrue(models.models().containsKey("add_mul")); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java index bfba5ae24c4..0dd3bd29a2c 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.io.GrowableByteBuffer; @@ -25,7 +25,7 @@ public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesC private final Path constantsPath; public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) { - super(fileAcquirer, new OnnxEvaluatorCache()); + super(fileAcquirer, new OnnxRuntime()); this.constantsPath = constantsPath; } diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 9b2b793212b..14da15f60d0 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.handler; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import ai.vespa.models.evaluation.ModelsEvaluator; import ai.vespa.models.evaluation.RankProfilesConfigImporterWithMockedConstants; import com.yahoo.config.subscription.ConfigGetter; @@ -323,7 +323,7 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSavedEvaluateSpecificFunction() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); Map<String, String> properties = new HashMap<>(); properties.put("input", inputTensor()); properties.put("format.tensors", "long"); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java index 86f56e14e2d..856031da72f 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.handler; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -27,7 +27,7 @@ public class OnnxEvaluationHandlerTest { @BeforeClass static public void setUp() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); handler = new HandlerTester(createModels()); } |