diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-02-22 11:51:06 +0100 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-02-22 12:04:35 +0100 |
commit | 7d69590e78f7e29dd7288a401e71732211a3b5dd (patch) | |
tree | 74286f892f873ee0309a72529447f2e575cbb15e /model-evaluation | |
parent | c5513d25475c78ce6a3ecd5e03b278f3eebca481 (diff) |
Cache Onnx model instances
Manage lifecycle of OnnxEvaluator instances explicitly to allow
instances to be cached without use WeakHashmap/finalizers.
Inject shared Onnx model cache in ModelsEvaluator.
Diffstat (limited to 'model-evaluation')
6 files changed, 44 insertions, 17 deletions
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index a5bda6e1c21..9fd25ac115b 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -47,7 +47,9 @@ }, "ai.vespa.models.evaluation.Model" : { "superClass" : "java.lang.Object", - "interfaces" : [ ], + "interfaces" : [ + "java.lang.AutoCloseable" + ], "attributes" : [ "public" ], @@ -56,7 +58,8 @@ "public java.lang.String name()", "public java.util.List functions()", "public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String[])", - "public java.lang.String toString()" + "public java.lang.String toString()", + "public void close()" ], "fields" : [ ] }, @@ -67,12 +70,14 @@ "public" ], "methods" : [ + "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)", "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)", "public void <init>(ai.vespa.models.evaluation.RankProfilesConfigImporter, com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)", "public void <init>(java.util.Map)", "public java.util.Map models()", "public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String, java.lang.String[])", - "public ai.vespa.models.evaluation.Model requireModel(java.lang.String)" + "public ai.vespa.models.evaluation.Model requireModel(java.lang.String)", + "public void deconstruct()" ], "fields" : [ ] }, @@ -83,7 +88,7 @@ "public" ], "methods" : [ - "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer)", + "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)", "public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)", "protected final java.lang.String readExpressionFromFile(java.io.File)", "protected com.yahoo.searchlib.rankingexpression.RankingExpression readExpressionFromFile(java.lang.String, com.yahoo.config.FileReference)", diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index d66d0330ea6..c317cdc5922 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -10,7 +10,6 @@ import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -22,7 +21,7 @@ import java.util.stream.Collectors; * @author bratseth */ @Beta -public class Model { +public class Model implements AutoCloseable { /** The prefix generated by model-integration/../IntermediateOperation */ private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; @@ -43,6 +42,8 @@ public class Model { private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); + private final List<Runnable> closeActions; + /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection<ExpressionFunction> functions) { this(name, @@ -101,6 +102,7 @@ public class Model { // Optimize functions this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream() .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName()))))); + this.closeActions = onnxModels.stream().map(o -> (Runnable)o::close).toList(); } /** Returns an optimized version of the given function */ @@ -223,4 +225,5 @@ public class Model { @Override public String toString() { return "model '" + name + "'"; } + @Override public void close() { closeActions.forEach(Runnable::run); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 28b613ca281..74233853ae9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; import com.yahoo.component.AbstractComponent; @@ -30,8 +31,17 @@ public class ModelsEvaluator extends AbstractComponent { RankingConstantsConfig constantsConfig, RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig, + FileAcquirer fileAcquirer, + OnnxEvaluatorCache cache) { + this(new RankProfilesConfigImporter(fileAcquirer, cache), config, constantsConfig, expressionsConfig, onnxModelsConfig); + } + + public ModelsEvaluator(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + RankingExpressionsConfig expressionsConfig, + OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer) { - this(new RankProfilesConfigImporter(fileAcquirer), config, constantsConfig, expressionsConfig, onnxModelsConfig); + this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxEvaluatorCache()); } public ModelsEvaluator(RankProfilesConfigImporter importer, @@ -69,4 +79,5 @@ public class ModelsEvaluator extends AbstractComponent { return model; } + @Override public void deconstruct() { models.values().forEach(Model::close); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index 19a9a1dccd5..ac66b1151f3 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -14,18 +15,20 @@ import java.util.Map; * * @author lesters */ -class OnnxModel { +class OnnxModel implements AutoCloseable { private final String name; private final File modelFile; private final OnnxEvaluatorOptions options; + private final OnnxEvaluatorCache cache; - private OnnxEvaluator evaluator; + private OnnxEvaluatorCache.ReferencedEvaluator referencedEvaluator; - OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options) { + OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxEvaluatorCache cache) { this.name = name; this.modelFile = modelFile; this.options = options; + this.cache = cache; } public String name() { @@ -33,8 +36,8 @@ class OnnxModel { } public void load() { - if (evaluator == null) { - evaluator = new OnnxEvaluator(modelFile.getPath(), options); + if (referencedEvaluator == null) { + referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options); } } @@ -51,10 +54,11 @@ class OnnxModel { } private OnnxEvaluator evaluator() { - if (evaluator == null) { + if (referencedEvaluator == null) { throw new IllegalStateException("ONNX model has not been loaded."); } - return evaluator; + return referencedEvaluator.evaluator(); } + @Override public void close() { referencedEvaluator.close(); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index e8aae24ca9e..2d91f86117e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.collections.Pair; import com.yahoo.config.FileReference; @@ -46,9 +47,11 @@ import java.util.regex.Pattern; public class RankProfilesConfigImporter { private final FileAcquirer fileAcquirer; + private final OnnxEvaluatorCache onnxEvaluatorCache; - public RankProfilesConfigImporter(FileAcquirer fileAcquirer) { + public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxEvaluatorCache onnxEvaluatorCache) { this.fileAcquirer = fileAcquirer; + this.onnxEvaluatorCache = onnxEvaluatorCache; } /** @@ -183,7 +186,7 @@ public class RankProfilesConfigImporter { options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required()); - return new OnnxModel(name, file, options); + return new OnnxModel(name, file, options, onnxEvaluatorCache); } catch (InterruptedException e) { throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java index c11f4764678..bfba5ae24c4 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.io.GrowableByteBuffer; @@ -24,7 +25,7 @@ public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesC private final Path constantsPath; public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) { - super(fileAcquirer); + super(fileAcquirer, new OnnxEvaluatorCache()); this.constantsPath = constantsPath; } |