diff options
28 files changed, 380 insertions, 274 deletions
diff --git a/application/pom.xml b/application/pom.xml index 2193f0fe2e3..236bcb6d81a 100644 --- a/application/pom.xml +++ b/application/pom.xml @@ -182,6 +182,12 @@ <artifactId>junit-jupiter-engine</artifactId> <scope>test</scope> </dependency> + <dependency> + <!-- Required for ContainerModelEvaluationTest --> + <groupId>com.microsoft.onnxruntime</groupId> + <artifactId>onnxruntime</artifactId> + <scope>test</scope> + </dependency> </dependencies> <build> diff --git a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java index f838d7a5481..cd5fd42a81a 100644 --- a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java +++ b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.application.container; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.application.Application; import com.yahoo.application.Networking; import com.yahoo.application.container.handler.Request; @@ -40,7 +40,7 @@ public class ContainerModelEvaluationTest { @Test void testCreateApplicationInstanceWithModelEvaluation() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); try (Application application = Application.fromApplicationPackage(new File("src/test/app-packages/model-evaluation"), Networking.disable)) { @@ -54,17 +54,17 @@ public class ContainerModelEvaluationTest { } { - String expected = "{\"cells\":[{\"address\":{},\"value\":2.496898}]}"; + String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":2.496898}]}"; assertResponse("http://localhost/model-evaluation/v1/xgboost_xgboost_2_2/eval?format.tensors=long", expected, jdisc); } { - String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}"; + String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}"; assertResponse("http://localhost/model-evaluation/v1/lightgbm_regression/eval?format.tensors=long", expected, jdisc); } { - String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.3671652674674988}]}"; + String expected = "{\"type\":\"tensor<float>(d0[3])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.36716532707214355}]}"; assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc); assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default.output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc); assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default/output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java index 57110b2431e..3d9a8441ed5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java @@ -44,7 +44,7 @@ public class ContainerModelEvaluation implements public ContainerModelEvaluation(ApplicationContainerCluster cluster, RankProfileList rankProfileList) { this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null"); cluster.addSimpleComponent(EVALUATOR_NAME, null, EVALUATION_BUNDLE_NAME); - cluster.addSimpleComponent("ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache", null, INTEGRATION_BUNDLE_NAME); + cluster.addSimpleComponent("ai.vespa.modelintegration.evaluator.OnnxRuntime", null, INTEGRATION_BUNDLE_NAME); cluster.addComponent(ContainerModelEvaluation.getHandler()); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java index 063f8f3109e..5b6c7b97875 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container.ml; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import ai.vespa.models.evaluation.FunctionEvaluator; import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.tensor.Tensor; @@ -21,7 +21,7 @@ public class ModelsEvaluatorTest { void testModelsEvaluator() { // Assumption fails but test passes on Intel macs // Assumption fails and test fails on ARM64 - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); ModelsEvaluator modelsEvaluator = ModelsEvaluatorTester.create("src/test/cfg/application/stateless_eval"); assertEquals(3, modelsEvaluator.models().size()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index caf0d22d44e..fc70a65b394 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; import ai.vespa.models.handler.ModelsEvaluationHandler; @@ -27,7 +27,10 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue; /** @@ -60,7 +63,7 @@ public class ModelEvaluationTest { @Test void testMl_serving() throws IOException { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); Path appDir = Path.fromString("src/test/cfg/application/ml_serving"); Path storedAppDir = appDir.append("copy"); try { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java index a731e9c7ccc..b0fe2c09227 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import ai.vespa.models.evaluation.FunctionEvaluator; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; @@ -45,7 +45,7 @@ public class StatelessOnnxEvaluationTest { @Test void testStatelessOnnxModelNameCollision() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); Path appDir = Path.fromString("src/test/cfg/application/onnx_name_collision"); try { ImportedModelTester tester = new ImportedModelTester("onnx", appDir); @@ -66,7 +66,7 @@ public class StatelessOnnxEvaluationTest { @Test void testStatelessOnnxModelEvaluation() throws Exception { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); Path appDir = Path.fromString("src/test/cfg/application/onnx"); Path storedAppDir = appDir.append("copy"); try { @@ -91,7 +91,7 @@ public class StatelessOnnxEvaluationTest { @Test void testStatelessOnnxModelEvaluationWithGpu() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); NodeResources resources = new NodeResources(4, 16, 125, 10, NodeResources.DiskSpeed.fast, NodeResources.StorageType.local, NodeResources.Architecture.x86_64, diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index 9fd25ac115b..667712d0daa 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -70,7 +70,7 @@ "public" ], "methods" : [ - "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)", + "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxRuntime)", "public void <init>(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig, com.yahoo.filedistribution.fileacquirer.FileAcquirer)", "public void <init>(ai.vespa.models.evaluation.RankProfilesConfigImporter, com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)", "public void <init>(java.util.Map)", @@ -88,7 +88,7 @@ "public" ], "methods" : [ - "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache)", + "public void <init>(com.yahoo.filedistribution.fileacquirer.FileAcquirer, ai.vespa.modelintegration.evaluator.OnnxRuntime)", "public java.util.Map importFrom(com.yahoo.vespa.config.search.RankProfilesConfig, com.yahoo.vespa.config.search.core.RankingConstantsConfig, com.yahoo.vespa.config.search.core.RankingExpressionsConfig, com.yahoo.vespa.config.search.core.OnnxModelsConfig)", "protected final java.lang.String readExpressionFromFile(java.io.File)", "protected com.yahoo.searchlib.rankingexpression.RankingExpression readExpressionFromFile(java.lang.String, com.yahoo.config.FileReference)", diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 74233853ae9..fd5306f9add 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; import com.yahoo.component.AbstractComponent; @@ -32,8 +32,8 @@ public class ModelsEvaluator extends AbstractComponent { RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer, - OnnxEvaluatorCache cache) { - this(new RankProfilesConfigImporter(fileAcquirer, cache), config, constantsConfig, expressionsConfig, onnxModelsConfig); + OnnxRuntime onnx) { + this(new RankProfilesConfigImporter(fileAcquirer, onnx), config, constantsConfig, expressionsConfig, onnxModelsConfig); } public ModelsEvaluator(RankProfilesConfig config, @@ -41,7 +41,7 @@ public class ModelsEvaluator extends AbstractComponent { RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig, FileAcquirer fileAcquirer) { - this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxEvaluatorCache()); + this(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer, new OnnxRuntime()); } public ModelsEvaluator(RankProfilesConfigImporter importer, diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index 73c5eb36539..cf97c20e881 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -2,8 +2,8 @@ package ai.vespa.models.evaluation; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; -import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -48,12 +48,12 @@ class OnnxModel implements AutoCloseable { final List<OutputSpec> outputSpecs = new ArrayList<>(); void addInputMapping(String onnxName, String source) { - if (referencedEvaluator != null) + if (evaluator != null) throw new IllegalStateException("input mapping must be added before load()"); inputSpecs.add(new InputSpec(onnxName, source)); } void addOutputMapping(String onnxName, String outputAs) { - if (referencedEvaluator != null) + if (evaluator != null) throw new IllegalStateException("output mapping must be added before load()"); outputSpecs.add(new OutputSpec(onnxName, outputAs)); } @@ -61,15 +61,15 @@ class OnnxModel implements AutoCloseable { private final String name; private final File modelFile; private final OnnxEvaluatorOptions options; - private final OnnxEvaluatorCache cache; + private final OnnxRuntime onnx; - private OnnxEvaluatorCache.ReferencedEvaluator referencedEvaluator; + private OnnxEvaluator evaluator; - OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxEvaluatorCache cache) { + OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) { this.name = name; this.modelFile = modelFile; this.options = options; - this.cache = cache; + this.onnx = onnx; } public String name() { @@ -77,8 +77,8 @@ class OnnxModel implements AutoCloseable { } public void load() { - if (referencedEvaluator == null) { - referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options); + if (evaluator == null) { + evaluator = onnx.evaluatorOf(modelFile.getPath(), options); fillInputTypes(evaluator().getInputs()); fillOutputTypes(evaluator().getOutputs()); } @@ -178,11 +178,11 @@ class OnnxModel implements AutoCloseable { } private OnnxEvaluator evaluator() { - if (referencedEvaluator == null) { + if (evaluator == null) { throw new IllegalStateException("ONNX model has not been loaded."); } - return referencedEvaluator.evaluator(); + return evaluator; } - @Override public void close() { referencedEvaluator.close(); } + @Override public void close() { evaluator.close(); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 6148287a536..8c520e87001 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,8 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.collections.Pair; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -47,11 +47,11 @@ import java.util.regex.Pattern; public class RankProfilesConfigImporter { private final FileAcquirer fileAcquirer; - private final OnnxEvaluatorCache onnxEvaluatorCache; + private final OnnxRuntime onnx; - public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxEvaluatorCache onnxEvaluatorCache) { + public RankProfilesConfigImporter(FileAcquirer fileAcquirer, OnnxRuntime onnx) { this.fileAcquirer = fileAcquirer; - this.onnxEvaluatorCache = onnxEvaluatorCache; + this.onnx = onnx; } /** @@ -198,7 +198,7 @@ public class RankProfilesConfigImporter { options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required()); - var m = new OnnxModel(name, file, options, onnxEvaluatorCache); + var m = new OnnxModel(name, file, options, onnx); for (var spec : onnxModelConfig.input()) { m.addInputMapping(spec.name(), spec.source()); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java index 992dae22aaf..0bee33be3cc 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; @@ -30,7 +30,7 @@ public class OnnxEvaluatorTest { @Test public void testOnnxEvaluation() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); ModelsEvaluator models = createModels(); assertTrue(models.models().containsKey("add_mul")); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java index bfba5ae24c4..0dd3bd29a2c 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import ai.vespa.modelintegration.evaluator.OnnxEvaluatorCache; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.io.GrowableByteBuffer; @@ -25,7 +25,7 @@ public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesC private final Path constantsPath; public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) { - super(fileAcquirer, new OnnxEvaluatorCache()); + super(fileAcquirer, new OnnxRuntime()); this.constantsPath = constantsPath; } diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 9b2b793212b..14da15f60d0 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.handler; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import ai.vespa.models.evaluation.ModelsEvaluator; import ai.vespa.models.evaluation.RankProfilesConfigImporterWithMockedConstants; import com.yahoo.config.subscription.ConfigGetter; @@ -323,7 +323,7 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSavedEvaluateSpecificFunction() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); Map<String, String> properties = new HashMap<>(); properties.put("input", inputTensor()); properties.put("format.tensors", "long"); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java index 86f56e14e2d..856031da72f 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.handler; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -27,7 +27,7 @@ public class OnnxEvaluationHandlerTest { @BeforeClass static public void setUp() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); handler = new HandlerTester(createModels()); } diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 8f26758cf65..9bb60827a68 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -69,6 +69,12 @@ <scope>provided</scope> </dependency> <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>component</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> <groupId>net.java.dev.jna</groupId> <artifactId>jna</artifactId> <scope>provided</scope> diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index 002350ce3cf..b0b4f871163 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -2,8 +2,10 @@ package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; -import com.yahoo.embedding.BertBaseEmbedderConfig; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.annotation.Inject; +import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.jdisc.AbstractResource; import com.yahoo.language.process.Embedder; import com.yahoo.language.wordpiece.WordPieceEmbedder; import com.yahoo.tensor.IndexedTensor; @@ -28,7 +30,7 @@ import java.util.Map; * * @author lesters */ -public class BertBaseEmbedder implements Embedder { +public class BertBaseEmbedder extends AbstractResource implements Embedder { private final static int TOKEN_CLS = 101; // [CLS] private final static int TOKEN_SEP = 102; // [SEP] @@ -44,7 +46,7 @@ public class BertBaseEmbedder implements Embedder { private final OnnxEvaluator evaluator; @Inject - public BertBaseEmbedder(BertBaseEmbedderConfig config) { + public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) { maxTokens = config.transformerMaxTokens(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); @@ -58,7 +60,7 @@ public class BertBaseEmbedder implements Embedder { options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads())); tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build(); - evaluator = new OnnxEvaluator(config.transformerModel().toString(), options); + this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options); validateModel(); } @@ -100,6 +102,8 @@ public class BertBaseEmbedder implements Embedder { return embedTokens(tokens, type); } + @Override protected void destroy() { evaluator.close(); } + Tensor embedTokens(List<Integer> tokens, TensorType type) { Tensor inputSequence = createTensorRepresentation(tokens, "d1"); Tensor attentionMask = createAttentionMask(inputSequence); diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 81150fe99b0..bad4bb5c9b3 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -3,22 +3,25 @@ package ai.vespa.embedding.huggingface; import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.annotation.Inject; +import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import com.yahoo.jdisc.AbstractResource; import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import java.io.*; +import java.io.IOException; import java.nio.file.Paths; -import java.util.*; +import java.util.Arrays; +import java.util.List; +import java.util.Map; -import org.slf4j.LoggerFactory; -import org.slf4j.Logger; - -public class HuggingFaceEmbedder implements Embedder { +public class HuggingFaceEmbedder extends AbstractResource implements Embedder { private static final Logger LOG = LoggerFactory.getLogger(HuggingFaceEmbedder.class.getName()); @@ -30,7 +33,7 @@ public class HuggingFaceEmbedder implements Embedder { private final OnnxEvaluator evaluator; @Inject - public HuggingFaceEmbedder(HuggingFaceEmbedderConfig config) throws IOException { + public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException { maxTokens = config.transformerMaxTokens(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); @@ -48,7 +51,7 @@ public class HuggingFaceEmbedder implements Embedder { LOG.info("Could not initialize the tokenizer"); throw new IOException("Could not initialize the tokenizer."); } - evaluator = new OnnxEvaluator(config.transformerModel().toString()); + evaluator = onnx.evaluatorOf(config.transformerModel().toString()); validateModel(); } @@ -83,6 +86,8 @@ public class HuggingFaceEmbedder implements Embedder { return tokenIds; } + @Override protected void destroy() { evaluator.close(); } + public List<Integer> longToInteger(long[] values) { return Arrays.stream(values) .boxed().map(Long::intValue) diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/Generator.java index ed231a5e94c..a08e2006e2c 100644 --- a/model-integration/src/main/java/ai/vespa/llm/Generator.java +++ b/model-integration/src/main/java/ai/vespa/llm/Generator.java @@ -2,7 +2,9 @@ package ai.vespa.llm; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.annotation.Inject; +import com.yahoo.jdisc.AbstractResource; import com.yahoo.language.process.Embedder; import com.yahoo.language.sentencepiece.SentencePieceEmbedder; import com.yahoo.llm.GeneratorConfig; @@ -25,7 +27,7 @@ import java.util.Map; * * @author lesters */ -public class Generator { +public class Generator extends AbstractResource { private final static int TOKEN_EOS = 1; // end of sequence @@ -46,7 +48,7 @@ public class Generator { private final OnnxEvaluator decoder; @Inject - public Generator(GeneratorConfig config) { + public Generator(OnnxRuntime onnx, GeneratorConfig config) { // Set up tokenizer tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build(); tokenizerMaxTokens = config.tokenizerMaxTokens(); @@ -61,7 +63,7 @@ public class Generator { encoderOptions.setInterOpThreads(modifyThreadCount(config.encoderOnnxInterOpThreads())); encoderOptions.setIntraOpThreads(modifyThreadCount(config.encoderOnnxIntraOpThreads())); - encoder = new OnnxEvaluator(config.encoderModel().toString(), encoderOptions); + encoder = onnx.evaluatorOf(config.encoderModel().toString(), encoderOptions); // Set up decoder decoderInputIdsName = config.decoderModelInputIdsName(); @@ -74,7 +76,7 @@ public class Generator { decoderOptions.setInterOpThreads(modifyThreadCount(config.decoderOnnxInterOpThreads())); decoderOptions.setIntraOpThreads(modifyThreadCount(config.decoderOnnxIntraOpThreads())); - decoder = new OnnxEvaluator(config.decoderModel().toString(), decoderOptions); + decoder = onnx.evaluatorOf(config.decoderModel().toString(), decoderOptions); validateModels(); } @@ -99,6 +101,8 @@ public class Generator { return generate(prompt, new GeneratorOptions()); } + @Override protected void destroy() { encoder.close(); decoder.close(); } + private String generateNotImplemented(GeneratorOptions options) { throw new UnsupportedOperationException("Search method '" + options.getSearchMethod() + "' is currently not implemented"); } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index c2d97e37074..7cdc27b6d63 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -5,9 +5,9 @@ package ai.vespa.modelintegration.evaluator; import ai.onnxruntime.NodeInfo; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; -import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; +import ai.vespa.modelintegration.evaluator.OnnxRuntime.ReferencedOrtSession; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -15,6 +15,8 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import static ai.vespa.modelintegration.evaluator.OnnxRuntime.isCudaError; + /** * Evaluates an ONNX Model by deferring to ONNX Runtime. @@ -23,24 +25,18 @@ import java.util.Map; */ public class OnnxEvaluator implements AutoCloseable { - private final OrtEnvironment environment; - private final OrtSession session; - - public OnnxEvaluator(String modelPath) { - this(modelPath, null); - } + private final ReferencedOrtSession session; - public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) { - environment = OrtEnvironment.getEnvironment(); - session = createSession(modelPath, environment, options, true); + OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options, OnnxRuntime runtime) { + session = createSession(modelPath, runtime, options, true); } public Tensor evaluate(Map<String, Tensor> inputs, String output) { Map<String, OnnxTensor> onnxInputs = null; try { output = mapToInternalName(output); - onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); - try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) { + onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance()); + try (OrtSession.Result result = session.instance().run(onnxInputs, Collections.singleton(output))) { return TensorConverter.toVespaTensor(result.get(0)); } } catch (OrtException e) { @@ -55,9 +51,9 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) { Map<String, OnnxTensor> onnxInputs = null; try { - onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); + onnxInputs = TensorConverter.toOnnxTensors(inputs, OnnxRuntime.ortEnvironment(), session.instance()); Map<String, Tensor> outputs = new HashMap<>(); - try (OrtSession.Result result = session.run(onnxInputs)) { + try (OrtSession.Result result = session.instance().run(onnxInputs)) { for (Map.Entry<String, OnnxValue> output : result) { String mapped = TensorConverter.asValidName(output.getKey()); outputs.put(mapped, TensorConverter.toVespaTensor(output.getValue())); @@ -88,7 +84,7 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, IdAndType> getInputs() { try { - return toSpecMap(session.getInputInfo()); + return toSpecMap(session.instance().getInputInfo()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } @@ -96,7 +92,7 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, IdAndType> getOutputs() { try { - return toSpecMap(session.getOutputInfo()); + return toSpecMap(session.instance().getOutputInfo()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } @@ -104,7 +100,7 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, TensorType> getInputInfo() { try { - return TensorConverter.toVespaTypes(session.getInputInfo()); + return TensorConverter.toVespaTypes(session.instance().getInputInfo()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } @@ -112,7 +108,7 @@ public class OnnxEvaluator implements AutoCloseable { public Map<String, TensorType> getOutputInfo() { try { - return TensorConverter.toVespaTypes(session.getOutputInfo()); + return TensorConverter.toVespaTypes(session.instance().getOutputInfo()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } @@ -122,26 +118,26 @@ public class OnnxEvaluator implements AutoCloseable { public void close() throws IllegalStateException { try { session.close(); - } catch (OrtException e) { + } catch (UncheckedOrtException e) { throw new IllegalStateException("Failed to close ONNX session", e); } catch (IllegalStateException e) { throw new IllegalStateException("Already closed", e); } } - private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) { + private static ReferencedOrtSession createSession(String modelPath, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) { if (options == null) { options = new OnnxEvaluatorOptions(); } try { - return environment.createSession(modelPath, options.getOptions(tryCuda && options.requestingGpu())); + return runtime.acquireSession(modelPath, options, tryCuda && options.requestingGpu()); } catch (OrtException e) { if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) { throw new IllegalArgumentException("No such file: " + modelPath); } if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) { // Failed in CUDA native code, but GPU device is optional, so we can proceed without it - return createSession(modelPath, environment, options, false); + return createSession(modelPath, runtime, options, false); } if (isCudaError(e)) { throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e); @@ -150,34 +146,8 @@ public class OnnxEvaluator implements AutoCloseable { } } - private static boolean isCudaError(OrtException e) { - return switch (e.getCode()) { - case ORT_FAIL -> e.getMessage().contains("cudaError"); - case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA"); - default -> false; - }; - } - - public static boolean isRuntimeAvailable() { - return isRuntimeAvailable(""); - } - - public static boolean isRuntimeAvailable(String modelPath) { - try { - new OnnxEvaluator(modelPath); - return true; - } catch (IllegalArgumentException e) { - if (e.getMessage().equals("No such file: ")) { - return true; - } - return false; - } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) { - return false; - } - } - private String mapToInternalName(String outputName) throws OrtException { - var info = session.getOutputInfo(); + var info = session.instance().getOutputInfo(); var internalNames = info.keySet(); for (String name : internalNames) { if (name.equals(outputName)) { diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java deleted file mode 100644 index b92ce24a6b4..00000000000 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCache.java +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.modelintegration.evaluator; - -import com.yahoo.jdisc.AbstractResource; -import com.yahoo.jdisc.ReferencedResource; -import com.yahoo.jdisc.ResourceReference; - -import javax.inject.Inject; -import java.util.HashMap; -import java.util.Map; - -/** - * Caches instances of {@link OnnxEvaluator}. - * - * @author bjorncs - */ -public class OnnxEvaluatorCache { - - // For mocking OnnxEvaluator in tests - @FunctionalInterface interface OnnxEvaluatorFactory { OnnxEvaluator create(String path, OnnxEvaluatorOptions opts); } - - private final Object monitor = new Object(); - private final Map<Id, SharedEvaluator> cache = new HashMap<>(); - private final OnnxEvaluatorFactory factory; - - @Inject public OnnxEvaluatorCache() { this(OnnxEvaluator::new); } - - OnnxEvaluatorCache(OnnxEvaluatorFactory factory) { this.factory = factory; } - - public ReferencedEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) { - synchronized (monitor) { - var id = new Id(modelPath, options); - var sharedInstance = cache.get(id); - if (sharedInstance == null) { - return newInstance(id); - } else { - ResourceReference reference; - try { - // refer() may throw if last reference was just released, but instance has not yet been removed from cache - reference = sharedInstance.refer(id); - } catch (IllegalStateException e) { - return newInstance(id); - } - return new ReferencedEvaluator(sharedInstance, reference); - } - } - } - - int size() { return cache.size(); } - - private ReferencedEvaluator newInstance(Id id) { - var evaluator = new SharedEvaluator(id, factory.create(id.modelPath, id.options)); - cache.put(id, evaluator); - var referenced = new ReferencedEvaluator(evaluator, evaluator.refer(id)); - // Release "main" reference to ensure that evaluator is destroyed when last external reference is released - evaluator.release(); - return referenced; - } - - // We assume options are never modified after being passed to cache - record Id(String modelPath, OnnxEvaluatorOptions options) {} - - public class ReferencedEvaluator extends ReferencedResource<SharedEvaluator> { - ReferencedEvaluator(SharedEvaluator resource, ResourceReference reference) { super(resource, reference); } - - public OnnxEvaluator evaluator() { return getResource().instance(); } - } - - public class SharedEvaluator extends AbstractResource { - private final Id id; - private final OnnxEvaluator instance; - - private SharedEvaluator(Id id, OnnxEvaluator instance) { - this.id = id; - this.instance = instance; - } - - public OnnxEvaluator instance() { return instance; } - - @Override - protected void destroy() { - synchronized (OnnxEvaluatorCache.this) { cache.remove(id); } - instance.close(); - } - } - -} diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java new file mode 100644 index 00000000000..42830041c02 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java @@ -0,0 +1,170 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.jdisc.ResourceReference; +import com.yahoo.jdisc.refcount.DebugReferencesWithStack; +import com.yahoo.jdisc.refcount.References; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static com.yahoo.yolean.Exceptions.throwUnchecked; + +/** + * Provides ONNX runtime environment with session management. + * + * @author bjorncs + */ +public class OnnxRuntime extends AbstractComponent { + + // For unit testing + @FunctionalInterface interface OrtSessionFactory { + OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException; + } + + private static final Logger log = Logger.getLogger(OnnxRuntime.class.getName()); + + private static final OrtEnvironmentResult ortEnvironment = getOrtEnvironment(); + private static final OrtSessionFactory defaultFactory = (path, opts) -> ortEnvironment().createSession(path, opts); + + private final Object monitor = new Object(); + private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>(); + private final OrtSessionFactory factory; + + @Inject public OnnxRuntime() { this(defaultFactory); } + + OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; } + + public OnnxEvaluator evaluatorOf(String modelPath) { + return new OnnxEvaluator(modelPath, null, this); + } + + public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) { + return new OnnxEvaluator(modelPath, options, this); + } + + public static OrtEnvironment ortEnvironment() { + if (ortEnvironment.env() != null) return ortEnvironment.env(); + throw throwUnchecked(ortEnvironment.failure()); + } + + @Override + public void deconstruct() { + synchronized (monitor) { + sessions.forEach((id, sharedSession) -> { + int hash = System.identityHashCode(sharedSession.session()); + var refs = sharedSession.references(); + log.warning("Closing leaked session %s (%s) with %d outstanding references:\n%s" + .formatted(id, hash, refs.referenceCount(), refs.currentState())); + try { + sharedSession.session().close(); + } catch (Exception e) { + log.log(Level.WARNING, "Failed to close session %s (%s)".formatted(id, hash), e); + } + }); + sessions.clear(); + } + } + + private static OrtEnvironmentResult getOrtEnvironment() { + try { + return new OrtEnvironmentResult(OrtEnvironment.getEnvironment(), null); + } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) { + log.log(Level.FINE, e, () -> "Failed to load ONNX runtime"); + return new OrtEnvironmentResult(null, e); + } + } + + public static boolean isRuntimeAvailable() { return ortEnvironment.env() != null; } + public static boolean isRuntimeAvailable(String modelPath) { + if (!isRuntimeAvailable()) return false; + try { + // Expensive way of checking if runtime is available as it incurs the cost of loading the model if successful + defaultFactory.create(modelPath, new OnnxEvaluatorOptions().getOptions(false)); + return true; + } catch (OrtException e) { + return e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE; + } catch (UnsatisfiedLinkError | RuntimeException | NoClassDefFoundError e) { + return false; + } + } + + static boolean isCudaError(OrtException e) { + return switch (e.getCode()) { + case ORT_FAIL -> e.getMessage().contains("cudaError"); + case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA"); + default -> false; + }; + } + + ReferencedOrtSession acquireSession(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException { + var sessionId = new OrtSessionId(modelPath, options, loadCuda); + synchronized (monitor) { + var sharedSession = sessions.get(sessionId); + if (sharedSession != null) { + return sharedSession.newReference(); + } + } + + // Note: identical models loaded simultaneously will result in duplicate session instances + var session = factory.create(modelPath, options.getOptions(loadCuda)); + log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(session))); + + var sharedSession = new SharedOrtSession(sessionId, session); + var referencedSession = sharedSession.newReference(); + synchronized (monitor) { sessions.put(sessionId, sharedSession); } + sharedSession.references().release(); // Release initial reference + return referencedSession; + } + + int sessionsCached() { synchronized(monitor) { return sessions.size(); } } + + public static class ReferencedOrtSession implements AutoCloseable { + private final OrtSession instance; + private final ResourceReference ref; + + public ReferencedOrtSession(OrtSession instance, ResourceReference ref) { + this.instance = instance; + this.ref = ref; + } + + public OrtSession instance() { return instance; } + @Override public void close() { ref.close(); } + } + + // Assumes options are never modified after being stored in `onnxSessions` + record OrtSessionId(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) {} + + record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {} + + private class SharedOrtSession { + private final OrtSessionId id; + private final OrtSession session; + private final References refs = new DebugReferencesWithStack(this::close); + + SharedOrtSession(OrtSessionId id, OrtSession session) { + this.id = id; + this.session = session; + } + + ReferencedOrtSession newReference() { return new ReferencedOrtSession(session, refs.refer(id)); } + References references() { return refs; } + OrtSession session() { return session; } + + void close() { + try { + synchronized (OnnxRuntime.this.monitor) { sessions.remove(id); } + log.fine(() -> "Closing session (%s)".formatted(System.identityHashCode(session))); + session.close(); + } catch (OrtException e) { throw new UncheckedOrtException(e);} + } + } +} diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java new file mode 100644 index 00000000000..1f2c8ba2cf7 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/UncheckedOrtException.java @@ -0,0 +1,15 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import ai.onnxruntime.OrtException; + +/** + * @author bjorncs + */ +public class UncheckedOrtException extends RuntimeException { + + public UncheckedOrtException(Throwable e) { super(e.getMessage(), e); } + + @Override public synchronized OrtException getCause() { return (OrtException) super.getCause(); } +} diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java index b06a54d68bb..329b87cacd1 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java @@ -1,13 +1,12 @@ package ai.vespa.embedding; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.ModelReference; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; -import java.lang.IllegalArgumentException; import java.util.List; import static org.junit.Assert.assertEquals; @@ -20,12 +19,12 @@ public class BertBaseEmbedderTest { public void testEmbedder() { String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt"; String modelPath = "src/test/models/onnx/transformer/dummy_transformer.onnx"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); builder.tokenizerVocab(ModelReference.valueOf(vocabPath)); builder.transformerModel(ModelReference.valueOf(modelPath)); - BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); + BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build()); TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer @@ -39,13 +38,13 @@ public class BertBaseEmbedderTest { public void testEmbedderWithoutTokenTypeIdsName() { String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt"; String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); builder.tokenizerVocab(ModelReference.valueOf(vocabPath)); builder.transformerModel(ModelReference.valueOf(modelPath)); builder.transformerTokenTypeIds(""); - BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); + BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build()); TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer @@ -59,14 +58,18 @@ public class BertBaseEmbedderTest { public void testEmbedderWithoutTokenTypeIdsNameButWithConfig() { String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt"; String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); builder.tokenizerVocab(ModelReference.valueOf(vocabPath)); builder.transformerModel(ModelReference.valueOf(modelPath)); // we did not configured BertBaseEmbedder to accept missing token type ids // so we expect ctor to throw - assertThrows(IllegalArgumentException.class, () -> { new BertBaseEmbedder(builder.build()); }); + assertThrows(IllegalArgumentException.class, () -> { newBertBaseEmbedder(builder.build()); }); + } + + private static BertBaseEmbedder newBertBaseEmbedder(BertBaseEmbedderConfig cfg) { + return new BertBaseEmbedder(new OnnxRuntime(), cfg); } } diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java index c67b6b0dcab..0ff9acc9a69 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java @@ -1,19 +1,5 @@ package ai.vespa.embedding.huggingface; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; -import com.yahoo.config.ModelReference; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Test; - -import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; - -import java.io.IOException; -import java.util.List; - -import static org.junit.Assume.assumeTrue; -import static org.junit.Assert.assertEquals; - public class HuggingFaceEmbedderTest { /* @Test @@ -21,7 +7,7 @@ public class HuggingFaceEmbedderTest { String modelPath = "src/test/models/hf/model.onnx"; String tokenizerPath = "src/test/models/hf/tokenizer.json"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder(); builder.tokenizerPath(ModelReference.valueOf(tokenizerPath)); diff --git a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java index 733430aa10d..c22902b344f 100644 --- a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java +++ b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java @@ -1,6 +1,6 @@ package ai.vespa.llm; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.ModelReference; import com.yahoo.llm.GeneratorConfig; import org.junit.Test; @@ -15,13 +15,13 @@ public class GeneratorTest { String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model"; String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx"; String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(encoderModelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(encoderModelPath)); GeneratorConfig.Builder builder = new GeneratorConfig.Builder(); builder.tokenizerModel(ModelReference.valueOf(vocabPath)); builder.encoderModel(ModelReference.valueOf(encoderModelPath)); builder.decoderModel(ModelReference.valueOf(decoderModelPath)); - Generator generator = new Generator(builder.build()); + Generator generator = newGenerator(builder.build()); GeneratorOptions options = new GeneratorOptions(); options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY); @@ -33,4 +33,8 @@ public class GeneratorTest { assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result); } + private static Generator newGenerator(GeneratorConfig cfg) { + return new Generator(new OnnxRuntime(), cfg); + } + } diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java deleted file mode 100644 index acce660f466..00000000000 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.modelintegration.evaluator; - -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotSame; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.mockito.Mockito.mock; - -/** - * @author bjorncs - */ -class OnnxEvaluatorCacheTest { - - @Test - void reuses_instance_while_in_use() { - var cache = new OnnxEvaluatorCache((__, ___) -> mock(OnnxEvaluator.class)); - var referencedEvaluator1 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions()); - var referencedEvaluator2 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions()); - var referencedEvaluator3 = cache.evaluatorOf("model2", new OnnxEvaluatorOptions()); - assertSame(referencedEvaluator1.evaluator(), referencedEvaluator2.evaluator()); - assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator3.evaluator()); - assertEquals(2, cache.size()); - referencedEvaluator1.close(); - referencedEvaluator2.close(); - assertEquals(1, cache.size()); - referencedEvaluator3.close(); - assertEquals(0, cache.size()); - var referencedEvaluator4 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions()); - assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator4.evaluator()); - assertEquals(1, cache.size()); - referencedEvaluator4.close(); - assertEquals(0, cache.size()); - } - -}
\ No newline at end of file diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index 83f355821e5..5aba54de11b 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -5,23 +5,31 @@ package ai.vespa.modelintegration.evaluator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeTrue; +import static org.junit.Assume.assumeNotNull; /** * @author lesters */ public class OnnxEvaluatorTest { + private static OnnxRuntime runtime; + + @BeforeAll + public static void beforeAll() { + if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime(); + } + @Test public void testSimpleModel() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); - OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx"); + assumeNotNull(runtime); + OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx"); // Input types Map<String, TensorType> inputTypes = evaluator.getInputInfo(); @@ -45,8 +53,8 @@ public class OnnxEvaluatorTest { @Test public void testBatchDimension() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); - OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx"); + assumeNotNull(runtime); + OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx"); // Input types Map<String, TensorType> inputTypes = evaluator.getInputInfo(); @@ -64,7 +72,7 @@ public class OnnxEvaluatorTest { @Test public void testMatMul() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeNotNull(runtime); String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]"; String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]"; String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]"; @@ -73,7 +81,7 @@ public class OnnxEvaluatorTest { @Test public void testTypes() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeNotNull(runtime); assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]"); @@ -86,8 +94,8 @@ public class OnnxEvaluatorTest { @Test public void testNotIdentifiers() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); - OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/badnames.onnx"); + assumeNotNull(runtime); + OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx"); var inputInfo = evaluator.getInputInfo(); var outputInfo = evaluator.getOutputInfo(); for (var entry : inputInfo.entrySet()) { @@ -152,7 +160,7 @@ public class OnnxEvaluatorTest { } private void assertEvaluate(String model, String output, String... input) { - OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/" + model); + OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model); Map<String, Tensor> inputs = new HashMap<>(); for (int i = 0; i < input.length; ++i) { inputs.put("input" + (i+1), Tensor.from(input[i])); diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java new file mode 100644 index 00000000000..81b1237e770 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java @@ -0,0 +1,48 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +/** + * @author bjorncs + */ +class OnnxRuntimeTest { + + @Test + void reuses_sessions_while_active() throws OrtException { + var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class)); + var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); + var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); + var session3 = runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false); + assertSame(session1.instance(), session2.instance()); + assertNotSame(session1.instance(), session3.instance()); + assertEquals(2, runtime.sessionsCached()); + + session1.close(); + session2.close(); + assertEquals(1, runtime.sessionsCached()); + verify(session1.instance()).close(); + verify(session3.instance(), never()).close(); + + session3.close(); + assertEquals(0, runtime.sessionsCached()); + verify(session3.instance()).close(); + + var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); + assertNotSame(session1.instance(), session4.instance()); + assertEquals(1, runtime.sessionsCached()); + session4.close(); + assertEquals(0, runtime.sessionsCached()); + verify(session4.instance()).close(); + } +}
\ No newline at end of file |