diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-02-22 11:51:06 +0100 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-02-22 12:04:35 +0100 |
commit | 7d69590e78f7e29dd7288a401e71732211a3b5dd (patch) | |
tree | 74286f892f873ee0309a72529447f2e575cbb15e /model-integration | |
parent | c5513d25475c78ce6a3ecd5e03b278f3eebca481 (diff) |
Cache Onnx model instances
Manage lifecycle of OnnxEvaluator instances explicitly to allow
instances to be cached without use WeakHashmap/finalizers.
Inject shared Onnx model cache in ModelsEvaluator.
Diffstat (limited to 'model-integration')
3 files changed, 141 insertions, 0 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 1302984a314..8f26758cf65 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -105,6 +105,21 @@ <scope>test</scope> </dependency> <dependency> + <groupId>org.junit.vintage</groupId> + <artifactId>junit-vintage-engine</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> <scope>test</scope> diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java new file mode 100644 index 00000000000..b92ce24a6b4 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java @@ -0,0 +1,88 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import com.yahoo.jdisc.AbstractResource; +import com.yahoo.jdisc.ReferencedResource; +import com.yahoo.jdisc.ResourceReference; + +import javax.inject.Inject; +import java.util.HashMap; +import java.util.Map; + +/** + * Caches instances of {@link OnnxEvaluator}. + * + * @author bjorncs + */ +public class OnnxEvaluatorCache { + + // For mocking OnnxEvaluator in tests + @FunctionalInterface interface OnnxEvaluatorFactory { OnnxEvaluator create(String path, OnnxEvaluatorOptions opts); } + + private final Object monitor = new Object(); + private final Map<Id, SharedEvaluator> cache = new HashMap<>(); + private final OnnxEvaluatorFactory factory; + + @Inject public OnnxEvaluatorCache() { this(OnnxEvaluator::new); } + + OnnxEvaluatorCache(OnnxEvaluatorFactory factory) { this.factory = factory; } + + public ReferencedEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) { + synchronized (monitor) { + var id = new Id(modelPath, options); + var sharedInstance = cache.get(id); + if (sharedInstance == null) { + return newInstance(id); + } else { + ResourceReference reference; + try { + // refer() may throw if last reference was just released, but instance has not yet been removed from cache + reference = sharedInstance.refer(id); + } catch (IllegalStateException e) { + return newInstance(id); + } + return new ReferencedEvaluator(sharedInstance, reference); + } + } + } + + int size() { return cache.size(); } + + private ReferencedEvaluator newInstance(Id id) { + var evaluator = new SharedEvaluator(id, factory.create(id.modelPath, id.options)); + cache.put(id, evaluator); + var referenced = new ReferencedEvaluator(evaluator, evaluator.refer(id)); + // Release "main" reference to ensure that evaluator is destroyed when last external reference is released + evaluator.release(); + return referenced; + } + + // We assume options are never modified after being passed to cache + record Id(String modelPath, OnnxEvaluatorOptions options) {} + + public class ReferencedEvaluator extends ReferencedResource<SharedEvaluator> { + ReferencedEvaluator(SharedEvaluator resource, ResourceReference reference) { super(resource, reference); } + + public OnnxEvaluator evaluator() { return getResource().instance(); } + } + + public class SharedEvaluator extends AbstractResource { + private final Id id; + private final OnnxEvaluator instance; + + private SharedEvaluator(Id id, OnnxEvaluator instance) { + this.id = id; + this.instance = instance; + } + + public OnnxEvaluator instance() { return instance; } + + @Override + protected void destroy() { + synchronized (OnnxEvaluatorCache.this) { cache.remove(id); } + instance.close(); + } + } + +} diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java new file mode 100644 index 00000000000..acce660f466 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java @@ -0,0 +1,38 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.mockito.Mockito.mock; + +/** + * @author bjorncs + */ +class OnnxEvaluatorCacheTest { + + @Test + void reuses_instance_while_in_use() { + var cache = new OnnxEvaluatorCache((__, ___) -> mock(OnnxEvaluator.class)); + var referencedEvaluator1 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions()); + var referencedEvaluator2 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions()); + var referencedEvaluator3 = cache.evaluatorOf("model2", new OnnxEvaluatorOptions()); + assertSame(referencedEvaluator1.evaluator(), referencedEvaluator2.evaluator()); + assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator3.evaluator()); + assertEquals(2, cache.size()); + referencedEvaluator1.close(); + referencedEvaluator2.close(); + assertEquals(1, cache.size()); + referencedEvaluator3.close(); + assertEquals(0, cache.size()); + var referencedEvaluator4 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions()); + assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator4.evaluator()); + assertEquals(1, cache.size()); + referencedEvaluator4.close(); + assertEquals(0, cache.size()); + } + +}
\ No newline at end of file |