author     Bjørn Christian Seime <bjorncs@yahooinc.com>    2023-02-27 17:02:23 +0100
committer  Bjørn Christian Seime <bjorncs@yahooinc.com>    2023-02-27 18:13:08 +0100
commit     5271b5d7241aa2aa2538b2072b8cae9b8f3d689a (patch)
tree       12f025b12e86e5f9490b74dd2cae68283f779e67 /model-integration
parent     6b40c6053b8542ae20a5bbe669f84f2d478fd697 (diff)
Replace `OnnxEvaluatorCache` with OnnxRuntime
Require an `OnnxRuntime` instance to create `OnnxEvaluator` instances. Cache underlying `OrtSession` instead of `OnnxEvaluator`. Move static helpers for checking Onnx runtime availability from `OnnxEvaluator` to `OnnxRuntime`.
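A minimal usage sketch of the resulting API, pieced together from the embedder and test changes in the diff below. The standalone `new OnnxRuntime()` wiring, the model path, and the output name are illustrative placeholders only; in production the runtime is injected as a component.

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.tensor.Tensor;

import java.util.Map;

class OnnxRuntimeUsageSketch {
    Tensor evaluateOnce(Map<String, Tensor> inputs) {
        OnnxRuntime onnx = new OnnxRuntime();                      // normally @Inject-ed as a component
        OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
        // evaluatorOf() replaces the removed public OnnxEvaluator constructors; identical
        // (modelPath, options) pairs now share one cached OrtSession under the hood.
        try (OnnxEvaluator evaluator = onnx.evaluatorOf("models/example.onnx", options)) { // placeholder path
            return evaluator.evaluate(inputs, "output");           // placeholder output name
        }                                                          // close() releases the session reference
    }
}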
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/pom.xml | 6
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java | 12
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 23
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/Generator.java | 12
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java | 68
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java | 88
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java | 170
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java | 15
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java | 19
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java | 16
-rw-r--r--  model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java | 10
-rw-r--r--  model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java | 38
-rw-r--r--  model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java | 28
-rw-r--r--  model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java | 48
14 files changed, 325 insertions, 228 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 8f26758cf65..9bb60827a68 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -69,6 +69,12 @@
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>component</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>net.java.dev.jna</groupId>
<artifactId>jna</artifactId>
<scope>provided</scope>
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index 002350ce3cf..b0b4f871163 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -2,8 +2,10 @@ package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
-import com.yahoo.embedding.BertBaseEmbedderConfig;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.annotation.Inject;
+import com.yahoo.embedding.BertBaseEmbedderConfig;
+import com.yahoo.jdisc.AbstractResource;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.wordpiece.WordPieceEmbedder;
import com.yahoo.tensor.IndexedTensor;
@@ -28,7 +30,7 @@ import java.util.Map;
*
* @author lesters
*/
-public class BertBaseEmbedder implements Embedder {
+public class BertBaseEmbedder extends AbstractResource implements Embedder {
private final static int TOKEN_CLS = 101; // [CLS]
private final static int TOKEN_SEP = 102; // [SEP]
@@ -44,7 +46,7 @@ public class BertBaseEmbedder implements Embedder {
private final OnnxEvaluator evaluator;
@Inject
- public BertBaseEmbedder(BertBaseEmbedderConfig config) {
+ public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) {
maxTokens = config.transformerMaxTokens();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
@@ -58,7 +60,7 @@ public class BertBaseEmbedder implements Embedder {
options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads()));
tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build();
- evaluator = new OnnxEvaluator(config.transformerModel().toString(), options);
+ this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options);
validateModel();
}
@@ -100,6 +102,8 @@ public class BertBaseEmbedder implements Embedder {
return embedTokens(tokens, type);
}
+ @Override protected void destroy() { evaluator.close(); }
+
Tensor embedTokens(List<Integer> tokens, TensorType type) {
Tensor inputSequence = createTensorRepresentation(tokens, "d1");
Tensor attentionMask = createAttentionMask(inputSequence);
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 81150fe99b0..bad4bb5c9b3 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -3,22 +3,25 @@ package ai.vespa.embedding.huggingface;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.annotation.Inject;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.jdisc.AbstractResource;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-import java.io.*;
+import java.io.IOException;
import java.nio.file.Paths;
-import java.util.*;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
-import org.slf4j.LoggerFactory;
-import org.slf4j.Logger;
-
-public class HuggingFaceEmbedder implements Embedder {
+public class HuggingFaceEmbedder extends AbstractResource implements Embedder {
private static final Logger LOG = LoggerFactory.getLogger(HuggingFaceEmbedder.class.getName());
@@ -30,7 +33,7 @@ public class HuggingFaceEmbedder implements Embedder {
private final OnnxEvaluator evaluator;
@Inject
- public HuggingFaceEmbedder(HuggingFaceEmbedderConfig config) throws IOException {
+ public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException {
maxTokens = config.transformerMaxTokens();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
@@ -48,7 +51,7 @@ public class HuggingFaceEmbedder implements Embedder {
LOG.info("Could not initialize the tokenizer");
throw new IOException("Could not initialize the tokenizer.");
}
- evaluator = new OnnxEvaluator(config.transformerModel().toString());
+ evaluator = onnx.evaluatorOf(config.transformerModel().toString());
validateModel();
}
@@ -83,6 +86,8 @@ public class HuggingFaceEmbedder implements Embedder {
return tokenIds;
}
+ @Override protected void destroy() { evaluator.close(); }
+
public List<Integer> longToInteger(long[] values) {
return Arrays.stream(values)
.boxed().map(Long::intValue)
diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/Generator.java
index ed231a5e94c..a08e2006e2c 100644
--- a/model-integration/src/main/java/ai/vespa/llm/Generator.java
+++ b/model-integration/src/main/java/ai/vespa/llm/Generator.java
@@ -2,7 +2,9 @@ package ai.vespa.llm;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.annotation.Inject;
+import com.yahoo.jdisc.AbstractResource;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.sentencepiece.SentencePieceEmbedder;
import com.yahoo.llm.GeneratorConfig;
@@ -25,7 +27,7 @@ import java.util.Map;
*
* @author lesters
*/
-public class Generator {
+public class Generator extends AbstractResource {
private final static int TOKEN_EOS = 1; // end of sequence
@@ -46,7 +48,7 @@ public class Generator {
private final OnnxEvaluator decoder;
@Inject
- public Generator(GeneratorConfig config) {
+ public Generator(OnnxRuntime onnx, GeneratorConfig config) {
// Set up tokenizer
tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build();
tokenizerMaxTokens = config.tokenizerMaxTokens();
@@ -61,7 +63,7 @@ public class Generator {
encoderOptions.setInterOpThreads(modifyThreadCount(config.encoderOnnxInterOpThreads()));
encoderOptions.setIntraOpThreads(modifyThreadCount(config.encoderOnnxIntraOpThreads()));
- encoder = new OnnxEvaluator(config.encoderModel().toString(), encoderOptions);
+ encoder = onnx.evaluatorOf(config.encoderModel().toString(), encoderOptions);
// Set up decoder
decoderInputIdsName = config.decoderModelInputIdsName();
@@ -74,7 +76,7 @@ public class Generator {
decoderOptions.setInterOpThreads(modifyThreadCount(config.decoderOnnxInterOpThreads()));
decoderOptions.setIntraOpThreads(modifyThreadCount(config.decoderOnnxIntraOpThreads()));
- decoder = new OnnxEvaluator(config.decoderModel().toString(), decoderOptions);
+ decoder = onnx.evaluatorOf(config.decoderModel().toString(), decoderOptions);
validateModels();
}
@@ -99,6 +101,8 @@ public class Generator {
return generate(prompt, new GeneratorOptions());
}
+ @Override protected void destroy() { encoder.close(); decoder.close(); }
+
private String generateNotImplemented(GeneratorOptions options) {
throw new UnsupportedOperationException("Search method '" + options.getSearchMethod() + "' is currently not implemented");
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index c2d97e37074..7cdc27b6d63 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -5,9 +5,9 @@ package ai.vespa.modelintegration.evaluator;
import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
-import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime.ReferencedOrtSession;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -15,6 +15,8 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+import static ai.vespa.modelintegration.evaluator.OnnxRuntime.isCudaError;
+
/**
* Evaluates an ONNX Model by deferring to ONNX Runtime.
@@ -23,24 +25,18 @@ import java.util.Map;
*/
public class OnnxEvaluator implements AutoCloseable {
- private final OrtEnvironment environment;
- private final OrtSession session;
-
- public OnnxEvaluator(String modelPath) {
- this(modelPath, null);
- }
+ private final ReferencedOrtSession session;
- public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
- environment = OrtEnvironment.getEnvironment();
- session = createSession(modelPath, environment, options, true);
+ OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options, OnnxRuntime runtime) {
+ session = createSession(modelPath, runtime, options, true);
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
Map<String, OnnxTensor> onnxInputs = null;
try {
output = mapToInternalName(output);
- onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
- try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) {
+ onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance());
+ try (OrtSession.Result result = session.instance().run(onnxInputs, Collections.singleton(output))) {
return TensorConverter.toVespaTensor(result.get(0));
}
} catch (OrtException e) {
@@ -55,9 +51,9 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
Map<String, OnnxTensor> onnxInputs = null;
try {
- onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
+ onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance());
Map<String, Tensor> outputs = new HashMap<>();
- try (OrtSession.Result result = session.run(onnxInputs)) {
+ try (OrtSession.Result result = session.instance().run(onnxInputs)) {
for (Map.Entry<String, OnnxValue> output : result) {
String mapped = TensorConverter.asValidName(output.getKey());
outputs.put(mapped, TensorConverter.toVespaTensor(output.getValue()));
@@ -88,7 +84,7 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, IdAndType> getInputs() {
try {
- return toSpecMap(session.getInputInfo());
+ return toSpecMap(session.instance().getInputInfo());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
@@ -96,7 +92,7 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, IdAndType> getOutputs() {
try {
- return toSpecMap(session.getOutputInfo());
+ return toSpecMap(session.instance().getOutputInfo());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
@@ -104,7 +100,7 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, TensorType> getInputInfo() {
try {
- return TensorConverter.toVespaTypes(session.getInputInfo());
+ return TensorConverter.toVespaTypes(session.instance().getInputInfo());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
@@ -112,7 +108,7 @@ public class OnnxEvaluator implements AutoCloseable {
public Map<String, TensorType> getOutputInfo() {
try {
- return TensorConverter.toVespaTypes(session.getOutputInfo());
+ return TensorConverter.toVespaTypes(session.instance().getOutputInfo());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
@@ -122,26 +118,26 @@ public class OnnxEvaluator implements AutoCloseable {
public void close() throws IllegalStateException {
try {
session.close();
- } catch (OrtException e) {
+ } catch (UncheckedOrtException e) {
throw new IllegalStateException("Failed to close ONNX session", e);
} catch (IllegalStateException e) {
throw new IllegalStateException("Already closed", e);
}
}
- private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) {
+ private static ReferencedOrtSession createSession(String modelPath, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) {
if (options == null) {
options = new OnnxEvaluatorOptions();
}
try {
- return environment.createSession(modelPath, options.getOptions(tryCuda && options.requestingGpu()));
+ return runtime.acquireSession(modelPath, options, tryCuda && options.requestingGpu());
} catch (OrtException e) {
if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
throw new IllegalArgumentException("No such file: " + modelPath);
}
if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
// Failed in CUDA native code, but GPU device is optional, so we can proceed without it
- return createSession(modelPath, environment, options, false);
+ return createSession(modelPath, runtime, options, false);
}
if (isCudaError(e)) {
throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e);
@@ -150,34 +146,8 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
- private static boolean isCudaError(OrtException e) {
- return switch (e.getCode()) {
- case ORT_FAIL -> e.getMessage().contains("cudaError");
- case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA");
- default -> false;
- };
- }
-
- public static boolean isRuntimeAvailable() {
- return isRuntimeAvailable("");
- }
-
- public static boolean isRuntimeAvailable(String modelPath) {
- try {
- new OnnxEvaluator(modelPath);
- return true;
- } catch (IllegalArgumentException e) {
- if (e.getMessage().equals("No such file: ")) {
- return true;
- }
- return false;
- } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) {
- return false;
- }
- }
-
private String mapToInternalName(String outputName) throws OrtException {
- var info = session.getOutputInfo();
+ var info = session.instance().getOutputInfo();
var internalNames = info.keySet();
for (String name : internalNames) {
if (name.equals(outputName)) {
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java
deleted file mode 100644
index b92ce24a6b4..00000000000
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java
+++ /dev/null
@@ -1,88 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.modelintegration.evaluator;
-
-import com.yahoo.jdisc.AbstractResource;
-import com.yahoo.jdisc.ReferencedResource;
-import com.yahoo.jdisc.ResourceReference;
-
-import javax.inject.Inject;
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * Caches instances of {@link OnnxEvaluator}.
- *
- * @author bjorncs
- */
-public class OnnxEvaluatorCache {
-
- // For mocking OnnxEvaluator in tests
- @FunctionalInterface interface OnnxEvaluatorFactory { OnnxEvaluator create(String path, OnnxEvaluatorOptions opts); }
-
- private final Object monitor = new Object();
- private final Map<Id, SharedEvaluator> cache = new HashMap<>();
- private final OnnxEvaluatorFactory factory;
-
- @Inject public OnnxEvaluatorCache() { this(OnnxEvaluator::new); }
-
- OnnxEvaluatorCache(OnnxEvaluatorFactory factory) { this.factory = factory; }
-
- public ReferencedEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
- synchronized (monitor) {
- var id = new Id(modelPath, options);
- var sharedInstance = cache.get(id);
- if (sharedInstance == null) {
- return newInstance(id);
- } else {
- ResourceReference reference;
- try {
- // refer() may throw if last reference was just released, but instance has not yet been removed from cache
- reference = sharedInstance.refer(id);
- } catch (IllegalStateException e) {
- return newInstance(id);
- }
- return new ReferencedEvaluator(sharedInstance, reference);
- }
- }
- }
-
- int size() { return cache.size(); }
-
- private ReferencedEvaluator newInstance(Id id) {
- var evaluator = new SharedEvaluator(id, factory.create(id.modelPath, id.options));
- cache.put(id, evaluator);
- var referenced = new ReferencedEvaluator(evaluator, evaluator.refer(id));
- // Release "main" reference to ensure that evaluator is destroyed when last external reference is released
- evaluator.release();
- return referenced;
- }
-
- // We assume options are never modified after being passed to cache
- record Id(String modelPath, OnnxEvaluatorOptions options) {}
-
- public class ReferencedEvaluator extends ReferencedResource<SharedEvaluator> {
- ReferencedEvaluator(SharedEvaluator resource, ResourceReference reference) { super(resource, reference); }
-
- public OnnxEvaluator evaluator() { return getResource().instance(); }
- }
-
- public class SharedEvaluator extends AbstractResource {
- private final Id id;
- private final OnnxEvaluator instance;
-
- private SharedEvaluator(Id id, OnnxEvaluator instance) {
- this.id = id;
- this.instance = instance;
- }
-
- public OnnxEvaluator instance() { return instance; }
-
- @Override
- protected void destroy() {
- synchronized (OnnxEvaluatorCache.this) { cache.remove(id); }
- instance.close();
- }
- }
-
-}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
new file mode 100644
index 00000000000..42830041c02
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
@@ -0,0 +1,170 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import com.yahoo.component.AbstractComponent;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.jdisc.ResourceReference;
+import com.yahoo.jdisc.refcount.DebugReferencesWithStack;
+import com.yahoo.jdisc.refcount.References;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import static com.yahoo.yolean.Exceptions.throwUnchecked;
+
+/**
+ * Provides ONNX runtime environment with session management.
+ *
+ * @author bjorncs
+ */
+public class OnnxRuntime extends AbstractComponent {
+
+ // For unit testing
+ @FunctionalInterface interface OrtSessionFactory {
+ OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException;
+ }
+
+ private static final Logger log = Logger.getLogger(OnnxRuntime.class.getName());
+
+ private static final OrtEnvironmentResult ortEnvironment = getOrtEnvironment();
+ private static final OrtSessionFactory defaultFactory = (path, opts) -> ortEnvironment().createSession(path, opts);
+
+ private final Object monitor = new Object();
+ private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>();
+ private final OrtSessionFactory factory;
+
+ @Inject public OnnxRuntime() { this(defaultFactory); }
+
+ OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; }
+
+ public OnnxEvaluator evaluatorOf(String modelPath) {
+ return new OnnxEvaluator(modelPath, null, this);
+ }
+
+ public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
+ return new OnnxEvaluator(modelPath, options, this);
+ }
+
+ public static OrtEnvironment ortEnvironment() {
+ if (ortEnvironment.env() != null) return ortEnvironment.env();
+ throw throwUnchecked(ortEnvironment.failure());
+ }
+
+ @Override
+ public void deconstruct() {
+ synchronized (monitor) {
+ sessions.forEach((id, sharedSession) -> {
+ int hash = System.identityHashCode(sharedSession.session());
+ var refs = sharedSession.references();
+ log.warning("Closing leaked session %s (%s) with %d outstanding references:\n%s"
+ .formatted(id, hash, refs.referenceCount(), refs.currentState()));
+ try {
+ sharedSession.session().close();
+ } catch (Exception e) {
+ log.log(Level.WARNING, "Failed to close session %s (%s)".formatted(id, hash), e);
+ }
+ });
+ sessions.clear();
+ }
+ }
+
+ private static OrtEnvironmentResult getOrtEnvironment() {
+ try {
+ return new OrtEnvironmentResult(OrtEnvironment.getEnvironment(), null);
+ } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) {
+ log.log(Level.FINE, e, () -> "Failed to load ONNX runtime");
+ return new OrtEnvironmentResult(null, e);
+ }
+ }
+
+ public static boolean isRuntimeAvailable() { return ortEnvironment.env() != null; }
+ public static boolean isRuntimeAvailable(String modelPath) {
+ if (!isRuntimeAvailable()) return false;
+ try {
+ // Expensive way of checking if runtime is available as it incurs the cost of loading the model if successful
+ defaultFactory.create(modelPath, new OnnxEvaluatorOptions().getOptions(false));
+ return true;
+ } catch (OrtException e) {
+ return e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE;
+ } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) {
+ return false;
+ }
+ }
+
+ static boolean isCudaError(OrtException e) {
+ return switch (e.getCode()) {
+ case ORT_FAIL -> e.getMessage().contains("cudaError");
+ case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA");
+ default -> false;
+ };
+ }
+
+ ReferencedOrtSession acquireSession(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException {
+ var sessionId = new OrtSessionId(modelPath, options, loadCuda);
+ synchronized (monitor) {
+ var sharedSession = sessions.get(sessionId);
+ if (sharedSession != null) {
+ return sharedSession.newReference();
+ }
+ }
+
+ // Note: identical models loaded simultaneously will result in duplicate session instances
+ var session = factory.create(modelPath, options.getOptions(loadCuda));
+ log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(session)));
+
+ var sharedSession = new SharedOrtSession(sessionId, session);
+ var referencedSession = sharedSession.newReference();
+ synchronized (monitor) { sessions.put(sessionId, sharedSession); }
+ sharedSession.references().release(); // Release initial reference
+ return referencedSession;
+ }
+
+ int sessionsCached() { synchronized(monitor) { return sessions.size(); } }
+
+ public static class ReferencedOrtSession implements AutoCloseable {
+ private final OrtSession instance;
+ private final ResourceReference ref;
+
+ public ReferencedOrtSession(OrtSession instance, ResourceReference ref) {
+ this.instance = instance;
+ this.ref = ref;
+ }
+
+ public OrtSession instance() { return instance; }
+ @Override public void close() { ref.close(); }
+ }
+
+ // Assumes options are never modified after being stored in `onnxSessions`
+ record OrtSessionId(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) {}
+
+ record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {}
+
+ private class SharedOrtSession {
+ private final OrtSessionId id;
+ private final OrtSession session;
+ private final References refs = new DebugReferencesWithStack(this::close);
+
+ SharedOrtSession(OrtSessionId id, OrtSession session) {
+ this.id = id;
+ this.session = session;
+ }
+
+ ReferencedOrtSession newReference() { return new ReferencedOrtSession(session, refs.refer(id)); }
+ References references() { return refs; }
+ OrtSession session() { return session; }
+
+ void close() {
+ try {
+ synchronized (OnnxRuntime.this.monitor) { sessions.remove(id); }
+ log.fine(() -> "Closing session (%s)".formatted(System.identityHashCode(session)));
+ session.close();
+ } catch (OrtException e) { throw new UncheckedOrtException(e);}
+ }
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java
new file mode 100644
index 00000000000..1f2c8ba2cf7
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java
@@ -0,0 +1,15 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OrtException;
+
+/**
+ * @author bjorncs
+ */
+public class UncheckedOrtException extends RuntimeException {
+
+ public UncheckedOrtException(Throwable e) { super(e.getMessage(), e); }
+
+ @Override public synchronized OrtException getCause() { return (OrtException) super.getCause(); }
+}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index b06a54d68bb..329b87cacd1 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -1,13 +1,12 @@
package ai.vespa.embedding;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
-import java.lang.IllegalArgumentException;
import java.util.List;
import static org.junit.Assert.assertEquals;
@@ -20,12 +19,12 @@ public class BertBaseEmbedderTest {
public void testEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
String modelPath = "src/test/models/onnx/transformer/dummy_transformer.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
- BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
+ BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");
List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer
@@ -39,13 +38,13 @@ public class BertBaseEmbedderTest {
public void testEmbedderWithoutTokenTypeIdsName() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
builder.transformerTokenTypeIds("");
- BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
+ BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");
List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer
@@ -59,14 +58,18 @@ public class BertBaseEmbedderTest {
public void testEmbedderWithoutTokenTypeIdsNameButWithConfig() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
// we did not configured BertBaseEmbedder to accept missing token type ids
// so we expect ctor to throw
- assertThrows(IllegalArgumentException.class, () -> { new BertBaseEmbedder(builder.build()); });
+ assertThrows(IllegalArgumentException.class, () -> { newBertBaseEmbedder(builder.build()); });
+ }
+
+ private static BertBaseEmbedder newBertBaseEmbedder(BertBaseEmbedderConfig cfg) {
+ return new BertBaseEmbedder(new OnnxRuntime(), cfg);
}
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
index c67b6b0dcab..0ff9acc9a69 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
@@ -1,19 +1,5 @@
package ai.vespa.embedding.huggingface;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
-import com.yahoo.config.ModelReference;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-
-import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
-
-import java.io.IOException;
-import java.util.List;
-
-import static org.junit.Assume.assumeTrue;
-import static org.junit.Assert.assertEquals;
-
public class HuggingFaceEmbedderTest {
/*
@Test
@@ -21,7 +7,7 @@ public class HuggingFaceEmbedderTest {
String modelPath = "src/test/models/hf/model.onnx";
String tokenizerPath = "src/test/models/hf/tokenizer.json";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
builder.tokenizerPath(ModelReference.valueOf(tokenizerPath));
diff --git a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
index 733430aa10d..c22902b344f 100644
--- a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
@@ -1,6 +1,6 @@
package ai.vespa.llm;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.llm.GeneratorConfig;
import org.junit.Test;
@@ -15,13 +15,13 @@ public class GeneratorTest {
String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";
String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx";
String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(encoderModelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(encoderModelPath));
GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
builder.tokenizerModel(ModelReference.valueOf(vocabPath));
builder.encoderModel(ModelReference.valueOf(encoderModelPath));
builder.decoderModel(ModelReference.valueOf(decoderModelPath));
- Generator generator = new Generator(builder.build());
+ Generator generator = newGenerator(builder.build());
GeneratorOptions options = new GeneratorOptions();
options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
@@ -33,4 +33,8 @@ public class GeneratorTest {
assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result);
}
+ private static Generator newGenerator(GeneratorConfig cfg) {
+ return new Generator(new OnnxRuntime(), cfg);
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
deleted file mode 100644
index acce660f466..00000000000
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.modelintegration.evaluator;
-
-import org.junit.jupiter.api.Test;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotSame;
-import static org.junit.jupiter.api.Assertions.assertSame;
-import static org.mockito.Mockito.mock;
-
-/**
- * @author bjorncs
- */
-class OnnxEvaluatorCacheTest {
-
- @Test
- void reuses_instance_while_in_use() {
- var cache = new OnnxEvaluatorCache((__, ___) -> mock(OnnxEvaluator.class));
- var referencedEvaluator1 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
- var referencedEvaluator2 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
- var referencedEvaluator3 = cache.evaluatorOf("model2", new OnnxEvaluatorOptions());
- assertSame(referencedEvaluator1.evaluator(), referencedEvaluator2.evaluator());
- assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator3.evaluator());
- assertEquals(2, cache.size());
- referencedEvaluator1.close();
- referencedEvaluator2.close();
- assertEquals(1, cache.size());
- referencedEvaluator3.close();
- assertEquals(0, cache.size());
- var referencedEvaluator4 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
- assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator4.evaluator());
- assertEquals(1, cache.size());
- referencedEvaluator4.close();
- assertEquals(0, cache.size());
- }
-
-}
\ No newline at end of file
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 83f355821e5..5aba54de11b 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -5,23 +5,31 @@ package ai.vespa.modelintegration.evaluator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
+import org.junit.jupiter.api.BeforeAll;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assume.assumeTrue;
+import static org.junit.Assume.assumeNotNull;
/**
* @author lesters
*/
public class OnnxEvaluatorTest {
+ private static OnnxRuntime runtime;
+
+ @BeforeAll
+ public static void beforeAll() {
+ if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime();
+ }
+
@Test
public void testSimpleModel() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx");
+ assumeNotNull(runtime);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx");
// Input types
Map<String, TensorType> inputTypes = evaluator.getInputInfo();
@@ -45,8 +53,8 @@ public class OnnxEvaluatorTest {
@Test
public void testBatchDimension() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx");
+ assumeNotNull(runtime);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx");
// Input types
Map<String, TensorType> inputTypes = evaluator.getInputInfo();
@@ -64,7 +72,7 @@ public class OnnxEvaluatorTest {
@Test
public void testMatMul() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeNotNull(runtime);
String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]";
String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]";
String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]";
@@ -73,7 +81,7 @@ public class OnnxEvaluatorTest {
@Test
public void testTypes() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeNotNull(runtime);
assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
@@ -86,8 +94,8 @@ public class OnnxEvaluatorTest {
@Test
public void testNotIdentifiers() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/badnames.onnx");
+ assumeNotNull(runtime);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx");
var inputInfo = evaluator.getInputInfo();
var outputInfo = evaluator.getOutputInfo();
for (var entry : inputInfo.entrySet()) {
@@ -152,7 +160,7 @@ public class OnnxEvaluatorTest {
}
private void assertEvaluate(String model, String output, String... input) {
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/" + model);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model);
Map<String, Tensor> inputs = new HashMap<>();
for (int i = 0; i < input.length; ++i) {
inputs.put("input" + (i+1), Tensor.from(input[i]));
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
new file mode 100644
index 00000000000..81b1237e770
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
@@ -0,0 +1,48 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+
+/**
+ * @author bjorncs
+ */
+class OnnxRuntimeTest {
+
+ @Test
+ void reuses_sessions_while_active() throws OrtException {
+ var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class));
+ var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
+ var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
+ var session3 = runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false);
+ assertSame(session1.instance(), session2.instance());
+ assertNotSame(session1.instance(), session3.instance());
+ assertEquals(2, runtime.sessionsCached());
+
+ session1.close();
+ session2.close();
+ assertEquals(1, runtime.sessionsCached());
+ verify(session1.instance()).close();
+ verify(session3.instance(), never()).close();
+
+ session3.close();
+ assertEquals(0, runtime.sessionsCached());
+ verify(session3.instance()).close();
+
+ var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
+ assertNotSame(session1.instance(), session4.instance());
+ assertEquals(1, runtime.sessionsCached());
+ session4.close();
+ assertEquals(0, runtime.sessionsCached());
+ verify(session4.instance()).close();
+ }
+}
\ No newline at end of file