diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-02-27 17:02:23 +0100 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-02-27 18:13:08 +0100 |
commit | 5271b5d7241aa2aa2538b2072b8cae9b8f3d689a (patch) | |
tree | 12f025b12e86e5f9490b74dd2cae68283f779e67 /model-integration/src/main/java/ai | |
parent | 6b40c6053b8542ae20a5bbe669f84f2d478fd697 (diff) |
Replace `OnnxEvaluatorCache` with OnnxRuntime
Require an `OnnxRuntime` instance to create `OnnxEvaluator` instances.
Cache underlying `OrtSession` instead of `OnnxEvaluator`.
Move static helpers for checking Onnx runtime availability from `OnnxEvaluator` to `OnnxRuntime`.
Diffstat (limited to 'model-integration/src/main/java/ai')
7 files changed, 234 insertions, 154 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index 002350ce3cf..b0b4f871163 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -2,8 +2,10 @@ package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; -import com.yahoo.embedding.BertBaseEmbedderConfig; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.annotation.Inject; +import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.jdisc.AbstractResource; import com.yahoo.language.process.Embedder; import com.yahoo.language.wordpiece.WordPieceEmbedder; import com.yahoo.tensor.IndexedTensor; @@ -28,7 +30,7 @@ import java.util.Map; * * @author lesters */ -public class BertBaseEmbedder implements Embedder { +public class BertBaseEmbedder extends AbstractResource implements Embedder { private final static int TOKEN_CLS = 101; // [CLS] private final static int TOKEN_SEP = 102; // [SEP] @@ -44,7 +46,7 @@ public class BertBaseEmbedder implements Embedder { private final OnnxEvaluator evaluator; @Inject - public BertBaseEmbedder(BertBaseEmbedderConfig config) { + public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) { maxTokens = config.transformerMaxTokens(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); @@ -58,7 +60,7 @@ public class BertBaseEmbedder implements Embedder { options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads())); tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build(); - evaluator = new OnnxEvaluator(config.transformerModel().toString(), options); + this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options); validateModel(); } @@ -100,6 +102,8 @@ public class BertBaseEmbedder implements Embedder { return embedTokens(tokens, type); } + @Override protected void destroy() { evaluator.close(); } + Tensor embedTokens(List<Integer> tokens, TensorType type) { Tensor inputSequence = createTensorRepresentation(tokens, "d1"); Tensor attentionMask = createAttentionMask(inputSequence); diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 81150fe99b0..bad4bb5c9b3 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -3,22 +3,25 @@ package ai.vespa.embedding.huggingface; import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.annotation.Inject; +import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import com.yahoo.jdisc.AbstractResource; import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import java.io.*; +import java.io.IOException; import java.nio.file.Paths; -import java.util.*; +import java.util.Arrays; +import java.util.List; +import java.util.Map; -import org.slf4j.LoggerFactory; -import org.slf4j.Logger; - -public class HuggingFaceEmbedder implements Embedder { +public class HuggingFaceEmbedder extends AbstractResource implements Embedder { private static final Logger LOG = LoggerFactory.getLogger(HuggingFaceEmbedder.class.getName()); @@ -30,7 +33,7 @@ public class HuggingFaceEmbedder implements Embedder { private final OnnxEvaluator evaluator; @Inject - public HuggingFaceEmbedder(HuggingFaceEmbedderConfig config) throws IOException { + public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException { maxTokens = config.transformerMaxTokens(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); @@ -48,7 +51,7 @@ public class HuggingFaceEmbedder implements Embedder { LOG.info("Could not initialize the tokenizer"); throw new IOException("Could not initialize the tokenizer."); } - evaluator = new OnnxEvaluator(config.transformerModel().toString()); + evaluator = onnx.evaluatorOf(config.transformerModel().toString()); validateModel(); } @@ -83,6 +86,8 @@ public class HuggingFaceEmbedder implements Embedder { return tokenIds; } + @Override protected void destroy() { evaluator.close(); } + public List<Integer> longToInteger(long[] values) { return Arrays.stream(values) .boxed().map(Long::intValue) diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/Generator.java index ed231a5e94c..a08e2006e2c 100644 --- a/model-integration/src/main/java/ai/vespa/llm/Generator.java +++ b/model-integration/src/main/java/ai/vespa/llm/Generator.java @@ -2,7 +2,9 @@ package ai.vespa.llm; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.annotation.Inject; +import com.yahoo.jdisc.AbstractResource; import com.yahoo.language.process.Embedder; import com.yahoo.language.sentencepiece.SentencePieceEmbedder; import com.yahoo.llm.GeneratorConfig; @@ -25,7 +27,7 @@ import java.util.Map; * * @author lesters */ -public class Generator { +public class Generator extends AbstractResource { private final static int TOKEN_EOS = 1; // end of sequence @@ -46,7 +48,7 @@ public class Generator { private final OnnxEvaluator decoder; @Inject - public Generator(GeneratorConfig config) { + public Generator(OnnxRuntime onnx, GeneratorConfig config) { // Set up tokenizer tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build(); tokenizerMaxTokens = config.tokenizerMaxTokens(); @@ -61,7 +63,7 @@ public class Generator { encoderOptions.setInterOpThreads(modifyThreadCount(config.encoderOnnxInterOpThreads())); encoderOptions.setIntraOpThreads(modifyThreadCount(config.encoderOnnxIntraOpThreads())); - encoder = new OnnxEvaluator(config.encoderModel().toString(), encoderOptions); + encoder = onnx.evaluatorOf(config.encoderModel().toString(), encoderOptions); // Set up decoder decoderInputIdsName = config.decoderModelInputIdsName(); @@ -74,7 +76,7 @@ public class Generator { decoderOptions.setInterOpThreads(modifyThreadCount(config.decoderOnnxInterOpThreads())); decoderOptions.setIntraOpThreads(modifyThreadCount(config.decoderOnnxIntraOpThreads())); - decoder = new OnnxEvaluator(config.decoderModel().toString(), decoderOptions); + decoder = onnx.evaluatorOf(config.decoderModel().toString(), decoderOptions); validateModels(); } @@ -99,6 +101,8 @@ public class Generator { return generate(prompt, new GeneratorOptions()); } + @Override protected void destroy() { encoder.close(); decoder.close(); } + private String generateNotImplemented(GeneratorOptions options) { throw new UnsupportedOperationException("Search method '" + options.getSearchMethod() + "' is currently not implemented"); } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index c2d97e37074..7cdc27b6d63 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -5,9 +5,9 @@ package ai.vespa.modelintegration.evaluator; import ai.onnxruntime.NodeInfo; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; +import ai.vespa.modelintegration.evaluator.OnnxRuntime.ReferencedOrtSession; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -15,6 +15,8 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import static ai.vespa.modelintegration.evaluator.OnnxRuntime.isCudaError; + /** * Evaluates an ONNX Model by deferring to ONNX Runtime. @@ -23,24 +25,18 @@ import java.util.Map; */ public class OnnxEvaluator implements AutoCloseable { - private final OrtEnvironment environment; - private final OrtSession session; - - public OnnxEvaluator(String modelPath) { - this(modelPath, null); - } + private final ReferencedOrtSession session; - public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) { - environment = OrtEnvironment.getEnvironment(); - session = createSession(modelPath, environment, options, true); + OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options, OnnxRuntime runtime) { + session = createSession(modelPath, runtime, options, true); } public Tensor evaluate(Map<String, Tensor> inputs, String output) { Map<String, OnnxTensor> onnxInputs = null; try { output = mapToInternalName(output); - onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); - try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) { + onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance()); + try (OrtSession.Result result = session.instance().run(onnxInputs, Collections.singleton(output))) { return TensorConverter.toVespaTensor(result.get(0)); } } catch (OrtException e) { @@ -55,9 +51,9 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) { Map<String, OnnxTensor> onnxInputs = null; try { - onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); + onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance()); Map<String, Tensor> outputs = new HashMap<>(); - try (OrtSession.Result result = session.run(onnxInputs)) { + try (OrtSession.Result result = session.instance().run(onnxInputs)) { for (Map.Entry<String, OnnxValue> output : result) { String mapped = TensorConverter.asValidName(output.getKey()); outputs.put(mapped, TensorConverter.toVespaTensor(output.getValue())); @@ -88,7 +84,7 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, IdAndType> getInputs() { try { - return toSpecMap(session.getInputInfo()); + return toSpecMap(session.instance().getInputInfo()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } @@ -96,7 +92,7 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, IdAndType> getOutputs() { try { - return toSpecMap(session.getOutputInfo()); + return toSpecMap(session.instance().getOutputInfo()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } @@ -104,7 +100,7 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, TensorType> getInputInfo() { try { - return TensorConverter.toVespaTypes(session.getInputInfo()); + return TensorConverter.toVespaTypes(session.instance().getInputInfo()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } @@ -112,7 +108,7 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, TensorType> getOutputInfo() { try { - return TensorConverter.toVespaTypes(session.getOutputInfo()); + return TensorConverter.toVespaTypes(session.instance().getOutputInfo()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } @@ -122,26 +118,26 @@ public class OnnxEvaluator implements AutoCloseable { public void close() throws IllegalStateException { try { session.close(); - } catch (OrtException e) { + } catch (UncheckedOrtException e) { throw new IllegalStateException("Failed to close ONNX session", e); } catch (IllegalStateException e) { throw new IllegalStateException("Already closed", e); } } - private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) { + private static ReferencedOrtSession createSession(String modelPath, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) { if (options == null) { options = new OnnxEvaluatorOptions(); } try { - return environment.createSession(modelPath, options.getOptions(tryCuda && options.requestingGpu())); + return runtime.acquireSession(modelPath, options, tryCuda && options.requestingGpu()); } catch (OrtException e) { if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) { throw new IllegalArgumentException("No such file: " + modelPath); } if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) { // Failed in CUDA native code, but GPU device is optional, so we can proceed without it - return createSession(modelPath, environment, options, false); + return createSession(modelPath, runtime, options, false); } if (isCudaError(e)) { throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e); @@ -150,34 +146,8 @@ public class OnnxEvaluator implements AutoCloseable { } } - private static boolean isCudaError(OrtException e) { - return switch (e.getCode()) { - case ORT_FAIL -> e.getMessage().contains("cudaError"); - case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA"); - default -> false; - }; - } - - public static boolean isRuntimeAvailable() { - return isRuntimeAvailable(""); - } - - public static boolean isRuntimeAvailable(String modelPath) { - try { - new OnnxEvaluator(modelPath); - return true; - } catch (IllegalArgumentException e) { - if (e.getMessage().equals("No such file: ")) { - return true; - } - return false; - } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) { - return false; - } - } - private String mapToInternalName(String outputName) throws OrtException { - var info = session.getOutputInfo(); + var info = session.instance().getOutputInfo(); var internalNames = info.keySet(); for (String name : internalNames) { if (name.equals(outputName)) { diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java deleted file mode 100644 index b92ce24a6b4..00000000000 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.modelintegration.evaluator; - -import com.yahoo.jdisc.AbstractResource; -import com.yahoo.jdisc.ReferencedResource; -import com.yahoo.jdisc.ResourceReference; - -import javax.inject.Inject; -import java.util.HashMap; -import java.util.Map; - -/** - * Caches instances of {@link OnnxEvaluator}. - * - * @author bjorncs - */ -public class OnnxEvaluatorCache { - - // For mocking OnnxEvaluator in tests - @FunctionalInterface interface OnnxEvaluatorFactory { OnnxEvaluator create(String path, OnnxEvaluatorOptions opts); } - - private final Object monitor = new Object(); - private final Map<Id, SharedEvaluator> cache = new HashMap<>(); - private final OnnxEvaluatorFactory factory; - - @Inject public OnnxEvaluatorCache() { this(OnnxEvaluator::new); } - - OnnxEvaluatorCache(OnnxEvaluatorFactory factory) { this.factory = factory; } - - public ReferencedEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) { - synchronized (monitor) { - var id = new Id(modelPath, options); - var sharedInstance = cache.get(id); - if (sharedInstance == null) { - return newInstance(id); - } else { - ResourceReference reference; - try { - // refer() may throw if last reference was just released, but instance has not yet been removed from cache - reference = sharedInstance.refer(id); - } catch (IllegalStateException e) { - return newInstance(id); - } - return new ReferencedEvaluator(sharedInstance, reference); - } - } - } - - int size() { return cache.size(); } - - private ReferencedEvaluator newInstance(Id id) { - var evaluator = new SharedEvaluator(id, factory.create(id.modelPath, id.options)); - cache.put(id, evaluator); - var referenced = new ReferencedEvaluator(evaluator, evaluator.refer(id)); - // Release "main" reference to ensure that evaluator is destroyed when last external reference is released - evaluator.release(); - return referenced; - } - - // We assume options are never modified after being passed to cache - record Id(String modelPath, OnnxEvaluatorOptions options) {} - - public class ReferencedEvaluator extends ReferencedResource<SharedEvaluator> { - ReferencedEvaluator(SharedEvaluator resource, ResourceReference reference) { super(resource, reference); } - - public OnnxEvaluator evaluator() { return getResource().instance(); } - } - - public class SharedEvaluator extends AbstractResource { - private final Id id; - private final OnnxEvaluator instance; - - private SharedEvaluator(Id id, OnnxEvaluator instance) { - this.id = id; - this.instance = instance; - } - - public OnnxEvaluator instance() { return instance; } - - @Override - protected void destroy() { - synchronized (OnnxEvaluatorCache.this) { cache.remove(id); } - instance.close(); - } - } - -} diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java new file mode 100644 index 00000000000..42830041c02 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java @@ -0,0 +1,170 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.jdisc.ResourceReference; +import com.yahoo.jdisc.refcount.DebugReferencesWithStack; +import com.yahoo.jdisc.refcount.References; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static com.yahoo.yolean.Exceptions.throwUnchecked; + +/** + * Provides ONNX runtime environment with session management. + * + * @author bjorncs + */ +public class OnnxRuntime extends AbstractComponent { + + // For unit testing + @FunctionalInterface interface OrtSessionFactory { + OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException; + } + + private static final Logger log = Logger.getLogger(OnnxRuntime.class.getName()); + + private static final OrtEnvironmentResult ortEnvironment = getOrtEnvironment(); + private static final OrtSessionFactory defaultFactory = (path, opts) -> ortEnvironment().createSession(path, opts); + + private final Object monitor = new Object(); + private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>(); + private final OrtSessionFactory factory; + + @Inject public OnnxRuntime() { this(defaultFactory); } + + OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; } + + public OnnxEvaluator evaluatorOf(String modelPath) { + return new OnnxEvaluator(modelPath, null, this); + } + + public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) { + return new OnnxEvaluator(modelPath, options, this); + } + + public static OrtEnvironment ortEnvironment() { + if (ortEnvironment.env() != null) return ortEnvironment.env(); + throw throwUnchecked(ortEnvironment.failure()); + } + + @Override + public void deconstruct() { + synchronized (monitor) { + sessions.forEach((id, sharedSession) -> { + int hash = System.identityHashCode(sharedSession.session()); + var refs = sharedSession.references(); + log.warning("Closing leaked session %s (%s) with %d outstanding references:\n%s" + .formatted(id, hash, refs.referenceCount(), refs.currentState())); + try { + sharedSession.session().close(); + } catch (Exception e) { + log.log(Level.WARNING, "Failed to close session %s (%s)".formatted(id, hash), e); + } + }); + sessions.clear(); + } + } + + private static OrtEnvironmentResult getOrtEnvironment() { + try { + return new OrtEnvironmentResult(OrtEnvironment.getEnvironment(), null); + } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) { + log.log(Level.FINE, e, () -> "Failed to load ONNX runtime"); + return new OrtEnvironmentResult(null, e); + } + } + + public static boolean isRuntimeAvailable() { return ortEnvironment.env() != null; } + public static boolean isRuntimeAvailable(String modelPath) { + if (!isRuntimeAvailable()) return false; + try { + // Expensive way of checking if runtime is available as it incurs the cost of loading the model if successful + defaultFactory.create(modelPath, new OnnxEvaluatorOptions().getOptions(false)); + return true; + } catch (OrtException e) { + return e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE; + } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) { + return false; + } + } + + static boolean isCudaError(OrtException e) { + return switch (e.getCode()) { + case ORT_FAIL -> e.getMessage().contains("cudaError"); + case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA"); + default -> false; + }; + } + + ReferencedOrtSession acquireSession(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException { + var sessionId = new OrtSessionId(modelPath, options, loadCuda); + synchronized (monitor) { + var sharedSession = sessions.get(sessionId); + if (sharedSession != null) { + return sharedSession.newReference(); + } + } + + // Note: identical models loaded simultaneously will result in duplicate session instances + var session = factory.create(modelPath, options.getOptions(loadCuda)); + log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(session))); + + var sharedSession = new SharedOrtSession(sessionId, session); + var referencedSession = sharedSession.newReference(); + synchronized (monitor) { sessions.put(sessionId, sharedSession); } + sharedSession.references().release(); // Release initial reference + return referencedSession; + } + + int sessionsCached() { synchronized(monitor) { return sessions.size(); } } + + public static class ReferencedOrtSession implements AutoCloseable { + private final OrtSession instance; + private final ResourceReference ref; + + public ReferencedOrtSession(OrtSession instance, ResourceReference ref) { + this.instance = instance; + this.ref = ref; + } + + public OrtSession instance() { return instance; } + @Override public void close() { ref.close(); } + } + + // Assumes options are never modified after being stored in `onnxSessions` + record OrtSessionId(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) {} + + record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {} + + private class SharedOrtSession { + private final OrtSessionId id; + private final OrtSession session; + private final References refs = new DebugReferencesWithStack(this::close); + + SharedOrtSession(OrtSessionId id, OrtSession session) { + this.id = id; + this.session = session; + } + + ReferencedOrtSession newReference() { return new ReferencedOrtSession(session, refs.refer(id)); } + References references() { return refs; } + OrtSession session() { return session; } + + void close() { + try { + synchronized (OnnxRuntime.this.monitor) { sessions.remove(id); } + log.fine(() -> "Closing session (%s)".formatted(System.identityHashCode(session))); + session.close(); + } catch (OrtException e) { throw new UncheckedOrtException(e);} + } + } +} diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java new file mode 100644 index 00000000000..1f2c8ba2cf7 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java @@ -0,0 +1,15 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import ai.onnxruntime.OrtException; + +/** + * @author bjorncs + */ +public class UncheckedOrtException extends RuntimeException { + + public UncheckedOrtException(Throwable e) { super(e.getMessage(), e); } + + @Override public synchronized OrtException getCause() { return (OrtException) super.getCause(); } +} |