aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-02-22 11:51:06 +0100
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-02-22 12:04:35 +0100
commit7d69590e78f7e29dd7288a401e71732211a3b5dd (patch)
tree74286f892f873ee0309a72529447f2e575cbb15e /model-integration
parentc5513d25475c78ce6a3ecd5e03b278f3eebca481 (diff)
Cache Onnx model instances
Manage the lifecycle of OnnxEvaluator instances explicitly, so that instances can be cached without the use of WeakHashMap/finalizers. Inject a shared Onnx model cache in ModelsEvaluator.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/pom.xml15
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java88
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java38
3 files changed, 141 insertions, 0 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 1302984a314..8f26758cf65 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -105,6 +105,21 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.junit.vintage</groupId>
+ <artifactId>junit-vintage-engine</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<scope>test</scope>
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java
new file mode 100644
index 00000000000..b92ce24a6b4
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java
@@ -0,0 +1,88 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import com.yahoo.jdisc.AbstractResource;
+import com.yahoo.jdisc.ReferencedResource;
+import com.yahoo.jdisc.ResourceReference;
+
+import javax.inject.Inject;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Caches instances of {@link OnnxEvaluator}; an instance is evicted and closed once its last reference is released.
+ *
+ * @author bjorncs
+ */
+public class OnnxEvaluatorCache {
+
+    // For mocking OnnxEvaluator in tests
+    @FunctionalInterface interface OnnxEvaluatorFactory { OnnxEvaluator create(String path, OnnxEvaluatorOptions opts); }
+
+    // Guards every access to 'cache', including the eviction done by SharedEvaluator.destroy()
+    private final Object monitor = new Object();
+    private final Map<Id, SharedEvaluator> cache = new HashMap<>();
+    private final OnnxEvaluatorFactory factory;
+
+    @Inject public OnnxEvaluatorCache() { this(OnnxEvaluator::new); }
+
+    OnnxEvaluatorCache(OnnxEvaluatorFactory factory) { this.factory = factory; }
+
+    /** Returns a reference to the evaluator for the given model and options, creating it if needed; close it when done. */
+    public ReferencedEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
+        synchronized (monitor) {
+            var id = new Id(modelPath, options);
+            var sharedInstance = cache.get(id);
+            if (sharedInstance == null) {
+                return newInstance(id);
+            }
+            ResourceReference reference;
+            try {
+                // refer() may throw if last reference was just released, but instance has not yet been removed from cache
+                reference = sharedInstance.refer(id);
+            } catch (IllegalStateException e) {
+                return newInstance(id); // overwrites the dying entry, whose destroy() must then leave the new one alone
+            }
+            return new ReferencedEvaluator(sharedInstance, reference);
+        }
+    }
+
+    int size() { synchronized (monitor) { return cache.size(); } }
+
+    // Creates an evaluator and stores it in the cache; the caller must hold 'monitor'
+    private ReferencedEvaluator newInstance(Id id) {
+        var evaluator = new SharedEvaluator(id, factory.create(id.modelPath, id.options));
+        cache.put(id, evaluator);
+        var referenced = new ReferencedEvaluator(evaluator, evaluator.refer(id));
+        // Release "main" reference to ensure that evaluator is destroyed when last external reference is released
+        evaluator.release();
+        return referenced;
+    }
+
+    // We assume options are never modified after being passed to cache
+    record Id(String modelPath, OnnxEvaluatorOptions options) {}
+
+    /** A reference to a cached evaluator; close it to release the reference. */
+    public class ReferencedEvaluator extends ReferencedResource<SharedEvaluator> {
+        ReferencedEvaluator(SharedEvaluator resource, ResourceReference reference) { super(resource, reference); }
+        public OnnxEvaluator evaluator() { return getResource().instance(); }
+    }
+
+    /** A cached evaluator which removes itself from the cache and closes the evaluator when destroyed. */
+    public class SharedEvaluator extends AbstractResource {
+        private final Id id;
+        private final OnnxEvaluator instance;
+        private SharedEvaluator(Id id, OnnxEvaluator instance) { this.id = id; this.instance = instance; }
+
+        public OnnxEvaluator instance() { return instance; }
+
+        @Override
+        protected void destroy() {
+            // Take the same lock as evaluatorOf(), and evict only if the entry was not already replaced by a new instance
+            synchronized (monitor) { cache.remove(id, this); }
+            instance.close();
+        }
+    }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
new file mode 100644
index 00000000000..acce660f466
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
@@ -0,0 +1,38 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.mockito.Mockito.mock;
+
+/**
+ * @author bjorncs
+ */
+class OnnxEvaluatorCacheTest {
+
+    @Test
+    void reuses_instance_while_in_use() {
+        var cache = new OnnxEvaluatorCache((path, opts) -> mock(OnnxEvaluator.class));
+        var model1First = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
+        var model1Second = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
+        var model2Only = cache.evaluatorOf("model2", new OnnxEvaluatorOptions());
+        assertSame(model1First.evaluator(), model1Second.evaluator());
+        assertNotSame(model1First.evaluator(), model2Only.evaluator());
+        assertEquals(2, cache.size()); // one entry per distinct model
+        model1First.close();
+        model1Second.close();
+        assertEquals(1, cache.size()); // model1 is evicted once both of its references are closed
+        model2Only.close();
+        assertEquals(0, cache.size());
+        var model1Reloaded = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
+        assertNotSame(model1First.evaluator(), model1Reloaded.evaluator()); // an evicted instance is not reused
+        assertEquals(1, cache.size());
+        model1Reloaded.close();
+        assertEquals(0, cache.size());
+    }
+
+} \ No newline at end of file