diff options
Diffstat (limited to 'model-evaluation/src')
5 files changed, 35 insertions, 13 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index d66d0330ea6..c317cdc5922 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -10,7 +10,6 @@ import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -22,7 +21,7 @@ import java.util.stream.Collectors; * @author bratseth */ @Beta -public class Model { +public class Model implements AutoCloseable { /** The prefix generated by model-integration/../IntermediateOperation */ private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; @@ -43,6 +42,8 @@ public class Model { private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); + private final List<Runnable> closeActions; + /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection<ExpressionFunction> functions) { this(name, @@ -101,6 +102,7 @@ public class Model { // Optimize functions this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream() .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName()))))); + this.closeActions = onnxModels.stream().map(o -> (Runnable)o::close).toList(); } /** Returns an optimized version of the given function */ @@ -223,4 +225,5 @@ public class Model { @Override public String toString() { return "model '" + name + "'"; } + @Override public void close() { closeActions.forEach(Runnable::run); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 28b613ca281..74233853ae9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; import com.yahoo.component.AbstractComponent; @@ -30,8 +31,17 @@ public class ModelsEvaluator extends AbstractComponent { RankingConstantsConfig constantsConfig, RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig, + FileAcquirer fileAcquirer, + OnnxEvaluatorCache cache) { + this(new RankProfilesConfigImporter(fileAcquirer, cache), config, constantsConfig, expressionsConfig, onnxModelsConfig); + } + + public ModelsEvaluator(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + RankingExpressionsConfig expressionsConfig, + OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer) { - this(new RankProfilesConfigImporter(fileAcquirer), config, constantsConfig, expressionsConfig, onnxModelsConfig); + this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxEvaluatorCache()); } public ModelsEvaluator(RankProfilesConfigImporter importer, @@ -69,4 +79,5 @@ public class ModelsEvaluator extends AbstractComponent { return model; } + @Override public void deconstruct() { models.values().forEach(Model::close); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index 19a9a1dccd5..ac66b1151f3 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -14,18 +15,20 @@ import java.util.Map; * * @author lesters */ -class OnnxModel { +class OnnxModel implements AutoCloseable { private final String name; private final File modelFile; private final OnnxEvaluatorOptions options; + private final OnnxEvaluatorCache cache; - private OnnxEvaluator evaluator; + private OnnxEvaluatorCache.ReferencedEvaluator referencedEvaluator; - OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options) { + OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxEvaluatorCache cache) { this.name = name; this.modelFile = modelFile; this.options = options; + this.cache = cache; } public String name() { @@ -33,8 +36,8 @@ class OnnxModel { } public void load() { - if (evaluator == null) { - evaluator = new OnnxEvaluator(modelFile.getPath(), options); + if (referencedEvaluator == null) { + referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options); } } @@ -51,10 +54,11 @@ class OnnxModel { } private OnnxEvaluator evaluator() { - if (evaluator == null) { + if (referencedEvaluator == null) { throw new IllegalStateException("ONNX model has not been loaded."); } - return evaluator; + return referencedEvaluator.evaluator(); } + @Override public void close() { referencedEvaluator.close(); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index e8aae24ca9e..2d91f86117e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.collections.Pair; import com.yahoo.config.FileReference; @@ -46,9 +47,11 @@ import java.util.regex.Pattern; public class RankProfilesConfigImporter { private final FileAcquirer fileAcquirer; + private final OnnxEvaluatorCache onnxEvaluatorCache; - public RankProfilesConfigImporter(FileAcquirer fileAcquirer) { + public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxEvaluatorCache onnxEvaluatorCache) { this.fileAcquirer = fileAcquirer; + this.onnxEvaluatorCache = onnxEvaluatorCache; } /** @@ -183,7 +186,7 @@ public class RankProfilesConfigImporter { options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required()); - return new OnnxModel(name, file, options); + return new OnnxModel(name, file, options, onnxEvaluatorCache); } catch (InterruptedException e) { throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java index c11f4764678..bfba5ae24c4 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.io.GrowableByteBuffer; @@ -24,7 +25,7 @@ public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesC private final Path constantsPath; public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) { - super(fileAcquirer); + super(fileAcquirer, new OnnxEvaluatorCache()); this.constantsPath = constantsPath; } |