summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java | 1
-rw-r--r-- model-evaluation/abi-spec.json | 13
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java | 7
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java | 13
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java | 18
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java | 7
-rw-r--r-- model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java | 3
-rw-r--r-- model-integration/pom.xml | 15
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java | 88
-rw-r--r-- model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java | 38
10 files changed, 186 insertions, 17 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
index 49292bd6df7..57110b2431e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
@@ -44,6 +44,7 @@ public class ContainerModelEvaluation implements
public ContainerModelEvaluation(ApplicationContainerCluster cluster, RankProfileList rankProfileList) {
this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null");
cluster.addSimpleComponent(EVALUATOR_NAME, null, EVALUATION_BUNDLE_NAME);
+ cluster.addSimpleComponent("ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache", null, INTEGRATION_BUNDLE_NAME);
cluster.addComponent(ContainerModelEvaluation.getHandler());
}
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index a5bda6e1c21..9fd25ac115b 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -47,7 +47,9 @@
},
"ai.vespa.models.evaluation.Model" : {
"superClass" : "java.lang.Object",
- "interfaces" : [ ],
+ "interfaces" : [
+ "java.lang.AutoCloseable"
+ ],
"attributes" : [
"public"
],
@@ -56,7 +58,8 @@
"public java.lang.String name()",
"public java.util.List functions()",
"public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String[])",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public void close()"
],
"fields" : [ ]
},
@@ -67,12 +70,14 @@
"public"
],
"methods" : [
+ "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)",
"public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
"public void <init>(ai.vespa.models.evaluation.RankProfilesConfigImporter, com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)",
"public void <init>(java.util.Map)",
"public java.util.Map models()",
"public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String, java.lang.String[])",
- "public ai.vespa.models.evaluation.Model requireModel(java.lang.String)"
+ "public ai.vespa.models.evaluation.Model requireModel(java.lang.String)",
+ "public void deconstruct()"
],
"fields" : [ ]
},
@@ -83,7 +88,7 @@
"public"
],
"methods" : [
- "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer)",
+ "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)",
"public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)",
"protected final java.lang.String readExpressionFromFile(java.io.File)",
"protected com.yahoo.searchlib.rankingexpression.RankingExpression readExpressionFromFile(java.lang.String, com.yahoo.config.FileReference)",
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index d66d0330ea6..c317cdc5922 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -10,7 +10,6 @@ import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
-import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -22,7 +21,7 @@ import java.util.stream.Collectors;
* @author bratseth
*/
@Beta
-public class Model {
+public class Model implements AutoCloseable {
/** The prefix generated by model-integration/../IntermediateOperation */
private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";
@@ -43,6 +42,8 @@ public class Model {
private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer();
+ private final List<Runnable> closeActions;
+
/** Programmatically create a model containing functions without constant of function references only */
public Model(String name, Collection<ExpressionFunction> functions) {
this(name,
@@ -101,6 +102,7 @@ public class Model {
// Optimize functions
this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream()
.collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName())))));
+ this.closeActions = onnxModels.stream().map(o -> (Runnable)o::close).toList();
}
/** Returns an optimized version of the given function */
@@ -223,4 +225,5 @@ public class Model {
@Override
public String toString() { return "model '" + name + "'"; }
+ @Override public void close() { closeActions.forEach(Runnable::run); }
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 28b613ca281..74233853ae9 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.AbstractComponent;
@@ -30,8 +31,17 @@ public class ModelsEvaluator extends AbstractComponent {
RankingConstantsConfig constantsConfig,
RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig,
+ FileAcquirer fileAcquirer,
+ OnnxEvaluatorCache cache) {
+ this(new RankProfilesConfigImporter(fileAcquirer, cache), config, constantsConfig, expressionsConfig, onnxModelsConfig);
+ }
+
+ public ModelsEvaluator(RankProfilesConfig config,
+ RankingConstantsConfig constantsConfig,
+ RankingExpressionsConfig expressionsConfig,
+ OnnxModelsConfig onnxModelsConfig,
FileAcquirer fileAcquirer) {
- this(new RankProfilesConfigImporter(fileAcquirer), config, constantsConfig, expressionsConfig, onnxModelsConfig);
+ this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxEvaluatorCache());
}
public ModelsEvaluator(RankProfilesConfigImporter importer,
@@ -69,4 +79,5 @@ public class ModelsEvaluator extends AbstractComponent {
return model;
}
+ @Override public void deconstruct() { models.values().forEach(Model::close); }
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index 19a9a1dccd5..ac66b1151f3 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -2,6 +2,7 @@
package ai.vespa.models.evaluation;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -14,18 +15,20 @@ import java.util.Map;
*
* @author lesters
*/
-class OnnxModel {
+class OnnxModel implements AutoCloseable {
private final String name;
private final File modelFile;
private final OnnxEvaluatorOptions options;
+ private final OnnxEvaluatorCache cache;
- private OnnxEvaluator evaluator;
+ private OnnxEvaluatorCache.ReferencedEvaluator referencedEvaluator;
- OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options) {
+ OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxEvaluatorCache cache) {
this.name = name;
this.modelFile = modelFile;
this.options = options;
+ this.cache = cache;
}
public String name() {
@@ -33,8 +36,8 @@ class OnnxModel {
}
public void load() {
- if (evaluator == null) {
- evaluator = new OnnxEvaluator(modelFile.getPath(), options);
+ if (referencedEvaluator == null) {
+ referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options);
}
}
@@ -51,10 +54,11 @@ class OnnxModel {
}
private OnnxEvaluator evaluator() {
- if (evaluator == null) {
+ if (referencedEvaluator == null) {
throw new IllegalStateException("ONNX model has not been loaded.");
}
- return evaluator;
+ return referencedEvaluator.evaluator();
}
+ @Override public void close() { if (referencedEvaluator != null) referencedEvaluator.close(); } // nothing to release unless load() acquired an evaluator
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index e8aae24ca9e..2d91f86117e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
@@ -46,9 +47,11 @@ import java.util.regex.Pattern;
public class RankProfilesConfigImporter {
private final FileAcquirer fileAcquirer;
+ private final OnnxEvaluatorCache onnxEvaluatorCache;
- public RankProfilesConfigImporter(FileAcquirer fileAcquirer) {
+ public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxEvaluatorCache onnxEvaluatorCache) {
this.fileAcquirer = fileAcquirer;
+ this.onnxEvaluatorCache = onnxEvaluatorCache;
}
/**
@@ -183,7 +186,7 @@ public class RankProfilesConfigImporter {
options.setInterOpThreads(onnxModelConfig.stateless_interop_threads());
options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads());
options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required());
- return new OnnxModel(name, file, options);
+ return new OnnxModel(name, file, options, onnxEvaluatorCache);
} catch (InterruptedException e) {
throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name());
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
index c11f4764678..bfba5ae24c4 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
@@ -24,7 +25,7 @@ public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesC
private final Path constantsPath;
public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) {
- super(fileAcquirer);
+ super(fileAcquirer, new OnnxEvaluatorCache());
this.constantsPath = constantsPath;
}
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 1302984a314..8f26758cf65 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -105,6 +105,21 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.junit.vintage</groupId>
+ <artifactId>junit-vintage-engine</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<scope>test</scope>
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java
new file mode 100644
index 00000000000..b92ce24a6b4
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java
@@ -0,0 +1,88 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import com.yahoo.jdisc.AbstractResource;
+import com.yahoo.jdisc.ReferencedResource;
+import com.yahoo.jdisc.ResourceReference;
+
+import javax.inject.Inject;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Caches instances of {@link OnnxEvaluator}.
+ *
+ * @author bjorncs
+ */
+public class OnnxEvaluatorCache {
+
+    // For mocking OnnxEvaluator in tests
+    @FunctionalInterface interface OnnxEvaluatorFactory { OnnxEvaluator create(String path, OnnxEvaluatorOptions opts); }
+
+    // Every read and write of 'cache' must happen while holding this monitor
+    private final Object monitor = new Object();
+    private final Map<Id, SharedEvaluator> cache = new HashMap<>();
+    private final OnnxEvaluatorFactory factory;
+
+    @Inject public OnnxEvaluatorCache() { this(OnnxEvaluator::new); }
+
+    OnnxEvaluatorCache(OnnxEvaluatorFactory factory) { this.factory = factory; }
+
+    /**
+     * Returns a reference to an evaluator for the given model path and options.
+     * An evaluator already cached under an equal path and options is shared; otherwise a new one
+     * is created through the factory and added to the cache.
+     *
+     * @param modelPath path of the ONNX model; part of the cache key
+     * @param options evaluator options; part of the cache key, so they must not be modified afterwards
+     * @return a reference which the caller must close once it no longer needs the evaluator
+     */
+    public ReferencedEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
+        synchronized (monitor) {
+            var id = new Id(modelPath, options);
+            var sharedInstance = cache.get(id);
+            if (sharedInstance == null) {
+                return newInstance(id);
+            } else {
+                ResourceReference reference;
+                try {
+                    // refer() may throw if last reference was just released, but instance has not yet been removed from cache
+                    reference = sharedInstance.refer(id);
+                } catch (IllegalStateException e) {
+                    // The cached instance is on its way out: replace it. Its destroy() will not evict the replacement.
+                    return newInstance(id);
+                }
+                return new ReferencedEvaluator(sharedInstance, reference);
+            }
+        }
+    }
+
+    /** Returns the number of evaluators currently in the cache */
+    int size() { synchronized (monitor) { return cache.size(); } }
+
+    /**
+     * Creates an evaluator for the given id and stores it in the cache, replacing any entry already
+     * stored under that id. Must be called with the monitor held.
+     */
+    private ReferencedEvaluator newInstance(Id id) {
+        var evaluator = new SharedEvaluator(id, factory.create(id.modelPath, id.options));
+        cache.put(id, evaluator);
+        var referenced = new ReferencedEvaluator(evaluator, evaluator.refer(id));
+        // Release "main" reference to ensure that evaluator is destroyed when last external reference is released
+        evaluator.release();
+        return referenced;
+    }
+
+    // We assume options are never modified after being passed to cache
+    record Id(String modelPath, OnnxEvaluatorOptions options) {}
+
+    /** A reference to a shared evaluator, to be closed by the holder when it is done with the evaluator. */
+    public class ReferencedEvaluator extends ReferencedResource<SharedEvaluator> {
+        ReferencedEvaluator(SharedEvaluator resource, ResourceReference reference) { super(resource, reference); }
+
+        public OnnxEvaluator evaluator() { return getResource().instance(); }
+    }
+
+    /** A cached evaluator which removes itself from the cache and closes the wrapped evaluator when destroyed. */
+    public class SharedEvaluator extends AbstractResource {
+        private final Id id;
+        private final OnnxEvaluator instance;
+
+        private SharedEvaluator(Id id, OnnxEvaluator instance) {
+            this.id = id;
+            this.instance = instance;
+        }
+
+        public OnnxEvaluator instance() { return instance; }
+
+        @Override
+        protected void destroy() {
+            // Lock on the same monitor as all other cache access, and remove only the mapping to this instance:
+            // evaluatorOf() may already have stored a replacement under the same id if refer() failed on this one.
+            synchronized (monitor) { cache.remove(id, this); }
+            instance.close();
+        }
+    }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
new file mode 100644
index 00000000000..acce660f466
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
@@ -0,0 +1,38 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.mockito.Mockito.mock;
+
+/**
+ * @author bjorncs
+ */
+class OnnxEvaluatorCacheTest {
+
+ // Verifies that the cache shares an evaluator between lookups with the same key while it is referenced,
+ // and drops the entry once every reference to it has been closed.
+ @Test
+ void reuses_instance_while_in_use() {
+ // The factory returns a new mock on every call, so instance identity reveals whether the cache reused an evaluator
+ var cache = new OnnxEvaluatorCache((__, ___) -> mock(OnnxEvaluator.class));
+ // Two lookups of "model1" and one of "model2"; relies on separately created default options comparing equal
+ var referencedEvaluator1 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
+ var referencedEvaluator2 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
+ var referencedEvaluator3 = cache.evaluatorOf("model2", new OnnxEvaluatorOptions());
+ assertSame(referencedEvaluator1.evaluator(), referencedEvaluator2.evaluator());
+ assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator3.evaluator());
+ assertEquals(2, cache.size());
+ // After both references to "model1" are closed, only "model2" remains cached
+ referencedEvaluator1.close();
+ referencedEvaluator2.close();
+ assertEquals(1, cache.size());
+ referencedEvaluator3.close();
+ assertEquals(0, cache.size());
+ // A lookup after eviction creates a new evaluator instead of handing out the old one
+ var referencedEvaluator4 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
+ assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator4.evaluator());
+ assertEquals(1, cache.size());
+ referencedEvaluator4.close();
+ assertEquals(0, cache.size());
+ }
+
+} \ No newline at end of file