diff options
12 files changed, 189 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 3c42987512b..d85c0065d84 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -8,6 +8,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.Optional; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -20,6 +21,10 @@ public class OnnxModel extends DistributableResource { private Map<String, String> inputMap = new HashMap<>(); private Map<String, String> outputMap = new HashMap<>(); + private String statelessExecutionMode = null; + private Integer statelessInterOpThreads = null; + private Integer statelessIntraOpThreads = null; + public OnnxModel(String name) { super(name); } @@ -79,4 +84,37 @@ public class OnnxModel extends DistributableResource { TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) { return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty; } + + public void setStatelessExecutionMode(String executionMode) { + if ("parallel".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "parallel"; + } else if ("sequential".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "sequential"; + } + } + + public Optional<String> getStatelessExecutionMode() { + return Optional.ofNullable(statelessExecutionMode); + } + + public void setStatelessInterOpThreads(int interOpThreads) { + if (interOpThreads >= 0) { + this.statelessInterOpThreads = interOpThreads; + } + } + + public Optional<Integer> getStatelessInterOpThreads() { + return Optional.ofNullable(statelessInterOpThreads); + } + + public void setStatelessIntraOpThreads(int intraOpThreads) { + if (intraOpThreads >= 0) { + this.statelessIntraOpThreads = intraOpThreads; + } + } + + public Optional<Integer> getStatelessIntraOpThreads() { + return Optional.ofNullable(statelessIntraOpThreads); + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java index cb61b8c9cec..dcbea70a32b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -22,6 +22,7 @@ public class OnnxModels { public OnnxModels(FileRegistry fileRegistry) { this.fileRegistry = fileRegistry; } + public void add(OnnxModel model) { model.validate(); model.register(fileRegistry); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index ad85f68cb8a..8291c69af2f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -130,6 +130,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } + public OnnxModels getOnnxModels() { + return onnxModels; + } + public Map<String, RawRankProfile> getRankProfiles() { return rankProfiles; } @@ -182,6 +186,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ modelBuilder.fileref(model.getFileReference()); model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + if (model.getStatelessExecutionMode().isPresent()) + modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); + if (model.getStatelessInterOpThreads().isPresent()) + modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); + if (model.getStatelessIntraOpThreads().isPresent()) + modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); + builder.model(modelBuilder); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index c318180fd56..c62dee68b2d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -33,6 +33,7 @@ import com.yahoo.config.provision.Zone; import com.yahoo.container.logging.FileConnectionLog; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.search.rendering.RendererRegistry; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.security.X509CertificateUtils; import com.yahoo.text.XML; @@ -557,9 +558,31 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { RankProfileList profiles = context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty; + + Element onnxElement = XML.getChild(modelEvaluationElement, "onnx"); + Element modelsElement = XML.getChild(onnxElement, "models"); + for (Element modelElement : XML.getChildren(modelsElement, "model") ) { + OnnxModel onnxModel = profiles.getOnnxModels().get(modelElement.getAttribute("name")); + if (onnxModel == null) + continue; // Skip if model is not found + onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null)); + onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1)); + onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1)); + } + cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles)); } + private String getStringValue(Element element, String name, String defaultValue) { + Element child = XML.getChild(element, name); + return (child != null) ? child.getTextContent() : defaultValue; + } + + private int getIntValue(Element element, String name, int defaultValue) { + Element child = XML.getChild(element, name); + return (child != null) ? Integer.parseInt(child.getTextContent()) : defaultValue; + } + protected void addModelEvaluationBundles(ApplicationContainerCluster cluster) { /* These bundles are added to all application container clusters, even if they haven't * declared 'model-evaluation' in services.xml, because there are many public API packages diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc index 992689a2189..945deec9f91 100644 --- a/config-model/src/main/resources/schema/containercluster.rnc +++ b/config-model/src/main/resources/schema/containercluster.rnc @@ -105,7 +105,16 @@ ZooKeeper = element zookeeper { } ModelEvaluation = element model-evaluation { - empty + element onnx { + element models { + element model { + attribute name { string } & + element intraop-threads { xsd:nonNegativeInteger }? & + element interop-threads { xsd:nonNegativeInteger }? & + element execution-mode { string "sequential" | string "parallel" }? + }* + }? + }? } Ssl = element ssl { diff --git a/config-model/src/test/cfg/application/onnx/services.xml b/config-model/src/test/cfg/application/onnx/services.xml index 8731558c6f7..1cae9250009 100644 --- a/config-model/src/test/cfg/application/onnx/services.xml +++ b/config-model/src/test/cfg/application/onnx/services.xml @@ -3,7 +3,19 @@ <services version="1.0"> <container version="1.0"> - <model-evaluation/> + <model-evaluation> + <onnx> + <models> + <model name="mul"> + <intraop-threads>2</intraop-threads> + </model> + <model name="non-existent-model"> + <interop-threads>400</interop-threads> + <execution-mode>parallel</execution-mode> + </model> + </models> + </onnx> + </model-evaluation> <nodes> <node hostalias="node1" /> </nodes> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java index b369560be74..27f1fc26b15 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java @@ -123,6 +123,10 @@ public class StatelessOnnxEvaluationTest { Tensor output = evaluator.bind("input1", input1).bind("input2", input2).evaluate(); assertEquals(6.0, output.sum().asDouble(), 1e-9); + OnnxModelsConfig.Model mulModel = onnxModelsConfig.model().get(0); + assertEquals(2, mulModel.stateless_intraop_threads()); + assertEquals(-1, mulModel.stateless_interop_threads()); + assertEquals("", mulModel.stateless_execution_mode()); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index dc27c43ef70..b014f60095e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -17,12 +18,14 @@ class OnnxModel { private final String name; private final File modelFile; + private final OnnxEvaluatorOptions options; private OnnxEvaluator evaluator; - OnnxModel(String name, File modelFile) { + OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options) { this.name = name; this.modelFile = modelFile; + this.options = options; } public String name() { @@ -31,7 +34,7 @@ class OnnxModel { public void load() { if (evaluator == null) { - evaluator = new OnnxEvaluator(modelFile.getPath()); + evaluator = new OnnxEvaluator(modelFile.getPath(), options); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index fbfd34814ac..335c39e02a1 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.collections.Pair; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -182,7 +183,13 @@ public class RankProfilesConfigImporter { try { String name = onnxModelConfig.name(); File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS); - return new OnnxModel(name, file); + + OnnxEvaluatorOptions options = new OnnxEvaluatorOptions(); + options.setExecutionMode(onnxModelConfig.stateless_execution_mode()); + options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); + options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); + + return new OnnxModel(name, file, options); } catch (InterruptedException e) { throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index b782a79f14b..4c44fca8c79 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -26,14 +26,16 @@ public class OnnxEvaluator { private final OrtSession session; public OnnxEvaluator(String modelPath) { + this(modelPath, null); + } + + public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) { try { + if (options == null) { + options = new OnnxEvaluatorOptions(); + } environment = OrtEnvironment.getEnvironment(); - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT); - options.setIntraOpNumThreads(Math.max(1, Runtime.getRuntime().availableProcessors() / 4)); - options.setInterOpNumThreads(1); - options.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL); - session = environment.createSession(modelPath, options); + session = environment.createSession(modelPath, options.getOptions()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java new file mode 100644 index 00000000000..8467040e5c0 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java @@ -0,0 +1,58 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; + +/** + * Session options for ONNX Runtime evaluation + * + * @author lesters + */ +public class OnnxEvaluatorOptions { + + private OrtSession.SessionOptions.OptLevel optimizationLevel; + private OrtSession.SessionOptions.ExecutionMode executionMode; + private int interOpThreads; + private int intraOpThreads; + + public OnnxEvaluatorOptions() { + // Defaults: + optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT; + executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL; + interOpThreads = 1; + intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4)); + } + + public OrtSession.SessionOptions getOptions() throws OrtException { + OrtSession.SessionOptions options = new OrtSession.SessionOptions(); + options.setOptimizationLevel(optimizationLevel); + options.setExecutionMode(executionMode); + options.setInterOpNumThreads(interOpThreads); + options.setIntraOpNumThreads(intraOpThreads); + return options; + } + + public void setExecutionMode(String mode) { + if ("parallel".equalsIgnoreCase(mode)) { + executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL; + } else if ("sequential".equalsIgnoreCase(mode)) { + executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL; + } + } + + public void setInterOpThreads(int threads) { + if (threads >= 0) { + interOpThreads = threads; + } + } + + public void setIntraOpThreads(int threads) { + if (threads >= 0) { + intraOpThreads = threads; + } + } + +} diff --git a/searchcore/src/vespa/searchcore/config/onnx-models.def b/searchcore/src/vespa/searchcore/config/onnx-models.def index 33ea90002c8..2c87a5a78ad 100644 --- a/searchcore/src/vespa/searchcore/config/onnx-models.def +++ b/searchcore/src/vespa/searchcore/config/onnx-models.def @@ -1,10 +1,13 @@ # Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. namespace=vespa.config.search.core -model[].name string -model[].fileref file -model[].input[].name string -model[].input[].source string -model[].output[].name string -model[].output[].as string -model[].dry_run_on_setup bool default=false +model[].name string +model[].fileref file +model[].input[].name string +model[].input[].source string +model[].output[].name string +model[].output[].as string +model[].dry_run_on_setup bool default=false +model[].stateless_execution_mode string default="" +model[].stateless_interop_threads int default=-1 +model[].stateless_intraop_threads int default=-1 |