diff options
10 files changed, 186 insertions, 17 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java index 49292bd6df7..57110b2431e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java @@ -44,6 +44,7 @@ public class ContainerModelEvaluation implements public ContainerModelEvaluation(ApplicationContainerCluster cluster, RankProfileList rankProfileList) { this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null"); cluster.addSimpleComponent(EVALUATOR_NAME, null, EVALUATION_BUNDLE_NAME); + cluster.addSimpleComponent("ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache", null, INTEGRATION_BUNDLE_NAME); cluster.addComponent(ContainerModelEvaluation.getHandler()); } diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index a5bda6e1c21..9fd25ac115b 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -47,7 +47,9 @@ }, "ai.vespa.models.evaluation.Model" : { "superClass" : "java.lang.Object", - "interfaces" : [ ], + "interfaces" : [ + "java.lang.AutoCloseable" + ], "attributes" : [ "public" ], @@ -56,7 +58,8 @@ "public java.lang.String name()", "public java.util.List functions()", "public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String[])", - "public java.lang.String toString()" + "public java.lang.String toString()", + "public void close()" ], "fields" : [ ] }, @@ -67,12 +70,14 @@ "public" ], "methods" : [ + "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, 
ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)", "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)", "public void <init>(ai.vespa.models.evaluation.RankProfilesConfigImporter, com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)", "public void <init>(java.util.Map)", "public java.util.Map models()", "public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String, java.lang.String[])", - "public ai.vespa.models.evaluation.Model requireModel(java.lang.String)" + "public ai.vespa.models.evaluation.Model requireModel(java.lang.String)", + "public void deconstruct()" ], "fields" : [ ] }, @@ -83,7 +88,7 @@ "public" ], "methods" : [ - "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer)", + "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)", "public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)", "protected final java.lang.String readExpressionFromFile(java.io.File)", "protected com.yahoo.searchlib.rankingexpression.RankingExpression readExpressionFromFile(java.lang.String, com.yahoo.config.FileReference)", diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index d66d0330ea6..c317cdc5922 100644 --- 
a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -10,7 +10,6 @@ import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -22,7 +21,7 @@ import java.util.stream.Collectors; * @author bratseth */ @Beta -public class Model { +public class Model implements AutoCloseable { /** The prefix generated by model-integration/../IntermediateOperation */ private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; @@ -43,6 +42,8 @@ public class Model { private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); + private final List<Runnable> closeActions; + /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection<ExpressionFunction> functions) { this(name, @@ -101,6 +102,7 @@ public class Model { // Optimize functions this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream() .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName()))))); + this.closeActions = onnxModels.stream().map(o -> (Runnable)o::close).toList(); } /** Returns an optimized version of the given function */ @@ -223,4 +225,5 @@ public class Model { @Override public String toString() { return "model '" + name + "'"; } + @Override public void close() { closeActions.forEach(Runnable::run); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 28b613ca281..74233853ae9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ 
b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; import com.yahoo.component.AbstractComponent; @@ -30,8 +31,17 @@ public class ModelsEvaluator extends AbstractComponent { RankingConstantsConfig constantsConfig, RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig, + FileAcquirer fileAcquirer, + OnnxEvaluatorCache cache) { + this(new RankProfilesConfigImporter(fileAcquirer, cache), config, constantsConfig, expressionsConfig, onnxModelsConfig); + } + + public ModelsEvaluator(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + RankingExpressionsConfig expressionsConfig, + OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer) { - this(new RankProfilesConfigImporter(fileAcquirer), config, constantsConfig, expressionsConfig, onnxModelsConfig); + this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxEvaluatorCache()); } public ModelsEvaluator(RankProfilesConfigImporter importer, @@ -69,4 +79,5 @@ public class ModelsEvaluator extends AbstractComponent { return model; } + @Override public void deconstruct() { models.values().forEach(Model::close); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index 19a9a1dccd5..ac66b1151f3 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import 
ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -14,18 +15,20 @@ import java.util.Map; * * @author lesters */ -class OnnxModel { +class OnnxModel implements AutoCloseable { private final String name; private final File modelFile; private final OnnxEvaluatorOptions options; + private final OnnxEvaluatorCache cache; - private OnnxEvaluator evaluator; + private OnnxEvaluatorCache.ReferencedEvaluator referencedEvaluator; - OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options) { + OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxEvaluatorCache cache) { this.name = name; this.modelFile = modelFile; this.options = options; + this.cache = cache; } public String name() { @@ -33,8 +36,8 @@ class OnnxModel { } public void load() { - if (evaluator == null) { - evaluator = new OnnxEvaluator(modelFile.getPath(), options); + if (referencedEvaluator == null) { + referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options); } } @@ -51,10 +54,11 @@ class OnnxModel { } private OnnxEvaluator evaluator() { - if (evaluator == null) { + if (referencedEvaluator == null) { throw new IllegalStateException("ONNX model has not been loaded."); } - return evaluator; + return referencedEvaluator.evaluator(); } + @Override public void close() { referencedEvaluator.close(); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index e8aae24ca9e..2d91f86117e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. 
See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.collections.Pair; import com.yahoo.config.FileReference; @@ -46,9 +47,11 @@ import java.util.regex.Pattern; public class RankProfilesConfigImporter { private final FileAcquirer fileAcquirer; + private final OnnxEvaluatorCache onnxEvaluatorCache; - public RankProfilesConfigImporter(FileAcquirer fileAcquirer) { + public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxEvaluatorCache onnxEvaluatorCache) { this.fileAcquirer = fileAcquirer; + this.onnxEvaluatorCache = onnxEvaluatorCache; } /** @@ -183,7 +186,7 @@ public class RankProfilesConfigImporter { options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required()); - return new OnnxModel(name, file, options); + return new OnnxModel(name, file, options, onnxEvaluatorCache); } catch (InterruptedException e) { throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java index c11f4764678..bfba5ae24c4 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.io.GrowableByteBuffer; @@ -24,7 +25,7 @@ public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesC private final Path constantsPath; public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) { - super(fileAcquirer); + super(fileAcquirer, new OnnxEvaluatorCache()); this.constantsPath = constantsPath; } diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 1302984a314..8f26758cf65 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -105,6 +105,21 @@ <scope>test</scope> </dependency> <dependency> + <groupId>org.junit.vintage</groupId> + <artifactId>junit-vintage-engine</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> <scope>test</scope> diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java new file mode 100644 index 00000000000..b92ce24a6b4 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java @@ -0,0 +1,88 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +package ai.vespa.modelintegration.evaluator; + +import com.yahoo.jdisc.AbstractResource; +import com.yahoo.jdisc.ReferencedResource; +import com.yahoo.jdisc.ResourceReference; + +import javax.inject.Inject; +import java.util.HashMap; +import java.util.Map; + +/** + * Caches instances of {@link OnnxEvaluator}. + * + * @author bjorncs + */ +public class OnnxEvaluatorCache { + + // For mocking OnnxEvaluator in tests + @FunctionalInterface interface OnnxEvaluatorFactory { OnnxEvaluator create(String path, OnnxEvaluatorOptions opts); } + + private final Object monitor = new Object(); + private final Map<Id, SharedEvaluator> cache = new HashMap<>(); + private final OnnxEvaluatorFactory factory; + + @Inject public OnnxEvaluatorCache() { this(OnnxEvaluator::new); } + + OnnxEvaluatorCache(OnnxEvaluatorFactory factory) { this.factory = factory; } + + public ReferencedEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) { + synchronized (monitor) { + var id = new Id(modelPath, options); + var sharedInstance = cache.get(id); + if (sharedInstance == null) { + return newInstance(id); + } else { + ResourceReference reference; + try { + // refer() may throw if last reference was just released, but instance has not yet been removed from cache + reference = sharedInstance.refer(id); + } catch (IllegalStateException e) { + return newInstance(id); + } + return new ReferencedEvaluator(sharedInstance, reference); + } + } + } + + int size() { return cache.size(); } + + private ReferencedEvaluator newInstance(Id id) { + var evaluator = new SharedEvaluator(id, factory.create(id.modelPath, id.options)); + cache.put(id, evaluator); + var referenced = new ReferencedEvaluator(evaluator, evaluator.refer(id)); + // Release "main" reference to ensure that evaluator is destroyed when last external reference is released + evaluator.release(); + return referenced; + } + + // We assume options are never modified after being passed to cache + record Id(String modelPath, 
OnnxEvaluatorOptions options) {} + + public class ReferencedEvaluator extends ReferencedResource<SharedEvaluator> { + ReferencedEvaluator(SharedEvaluator resource, ResourceReference reference) { super(resource, reference); } + + public OnnxEvaluator evaluator() { return getResource().instance(); } + } + + public class SharedEvaluator extends AbstractResource { + private final Id id; + private final OnnxEvaluator instance; + + private SharedEvaluator(Id id, OnnxEvaluator instance) { + this.id = id; + this.instance = instance; + } + + public OnnxEvaluator instance() { return instance; } + + @Override + protected void destroy() { + synchronized (OnnxEvaluatorCache.this) { cache.remove(id); } + instance.close(); + } + } + +} diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java new file mode 100644 index 00000000000..acce660f466 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java @@ -0,0 +1,38 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+
+package ai.vespa.modelintegration.evaluator;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.mockito.Mockito.mock;
+
+/** Checks that {@link OnnxEvaluatorCache} shares an evaluator between live references and drops it once all of them are closed.
+ * @author bjorncs
+ */
+class OnnxEvaluatorCacheTest {
+
+    @Test
+    void reuses_instance_while_in_use() {
+        var cache = new OnnxEvaluatorCache((__, ___) -> mock(OnnxEvaluator.class)); // factory hands out a new mock on every creation
+        var referencedEvaluator1 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
+        var referencedEvaluator2 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
+        var referencedEvaluator3 = cache.evaluatorOf("model2", new OnnxEvaluatorOptions());
+        assertSame(referencedEvaluator1.evaluator(), referencedEvaluator2.evaluator()); // same path and options: shared
+        assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator3.evaluator()); // different path: separate
+        assertEquals(2, cache.size());
+        referencedEvaluator1.close();
+        referencedEvaluator2.close();
+        assertEquals(1, cache.size()); // "model1" entry is gone once both of its references are closed
+        referencedEvaluator3.close();
+        assertEquals(0, cache.size());
+        var referencedEvaluator4 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
+        assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator4.evaluator()); // requesting again after eviction creates a new evaluator
+        assertEquals(1, cache.size());
+        referencedEvaluator4.close();
+        assertEquals(0, cache.size());
+    }
+
+}
\ No newline at end of file