From 387da217a9eb2a6f88b50f3608659a7d75c66aeb Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 30 Sep 2021 16:05:03 +0200 Subject: Add parsing of ONNX Runtime session options to services.xml --- .../java/com/yahoo/searchdefinition/OnnxModel.java | 38 ++++++++++++++++++++++ .../com/yahoo/searchdefinition/OnnxModels.java | 1 + .../searchdefinition/derived/RankProfileList.java | 11 +++++++ .../model/container/xml/ContainerModelBuilder.java | 23 +++++++++++++ .../src/main/resources/schema/containercluster.rnc | 11 ++++++- .../src/test/cfg/application/onnx/services.xml | 14 +++++++- .../model/ml/StatelessOnnxEvaluationTest.java | 4 +++ 7 files changed, 100 insertions(+), 2 deletions(-) (limited to 'config-model/src') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 3c42987512b..d85c0065d84 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -8,6 +8,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.Optional; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -20,6 +21,10 @@ public class OnnxModel extends DistributableResource { private Map inputMap = new HashMap<>(); private Map outputMap = new HashMap<>(); + private String statelessExecutionMode = null; + private Integer statelessInterOpThreads = null; + private Integer statelessIntraOpThreads = null; + public OnnxModel(String name) { super(name); } @@ -79,4 +84,37 @@ public class OnnxModel extends DistributableResource { TensorType getTensorType(String onnxName, Map inputTypes) { return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty; } + + public void setStatelessExecutionMode(String executionMode) { + if ("parallel".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "parallel"; + } else if ("sequential".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "sequential"; + } + } + + public Optional getStatelessExecutionMode() { + return Optional.ofNullable(statelessExecutionMode); + } + + public void setStatelessInterOpThreads(int interOpThreads) { + if (interOpThreads >= 0) { + this.statelessInterOpThreads = interOpThreads; + } + } + + public Optional getStatelessInterOpThreads() { + return Optional.ofNullable(statelessInterOpThreads); + } + + public void setStatelessIntraOpThreads(int intraOpThreads) { + if (intraOpThreads >= 0) { + this.statelessIntraOpThreads = intraOpThreads; + } + } + + public Optional getStatelessIntraOpThreads() { + return Optional.ofNullable(statelessIntraOpThreads); + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java index cb61b8c9cec..dcbea70a32b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -22,6 +22,7 @@ public class OnnxModels { public OnnxModels(FileRegistry fileRegistry) { this.fileRegistry = fileRegistry; } + public void add(OnnxModel model) { model.validate(); model.register(fileRegistry); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index ad85f68cb8a..8291c69af2f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -130,6 +130,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } + public OnnxModels getOnnxModels() { + return onnxModels; + } + public Map getRankProfiles() { return rankProfiles; } @@ -182,6 +186,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ modelBuilder.fileref(model.getFileReference()); model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + if (model.getStatelessExecutionMode().isPresent()) + modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); + if (model.getStatelessInterOpThreads().isPresent()) + modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); + if (model.getStatelessIntraOpThreads().isPresent()) + modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); + builder.model(modelBuilder); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index c318180fd56..c62dee68b2d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -33,6 +33,7 @@ import com.yahoo.config.provision.Zone; import com.yahoo.container.logging.FileConnectionLog; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.search.rendering.RendererRegistry; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.security.X509CertificateUtils; import com.yahoo.text.XML; @@ -557,9 +558,31 @@ public class ContainerModelBuilder extends ConfigModelBuilder { RankProfileList profiles = context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty; + + Element onnxElement = XML.getChild(modelEvaluationElement, "onnx"); + Element modelsElement = XML.getChild(onnxElement, "models"); + for (Element modelElement : XML.getChildren(modelsElement, "model") ) { + OnnxModel onnxModel = profiles.getOnnxModels().get(modelElement.getAttribute("name")); + if (onnxModel == null) + continue; // Skip if model is not found + onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null)); + onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1)); + onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1)); + } + cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles)); } + private String getStringValue(Element element, String name, String defaultValue) { + Element child = XML.getChild(element, name); + return (child != null) ? child.getTextContent() : defaultValue; + } + + private int getIntValue(Element element, String name, int defaultValue) { + Element child = XML.getChild(element, name); + return (child != null) ? Integer.parseInt(child.getTextContent()) : defaultValue; + } + protected void addModelEvaluationBundles(ApplicationContainerCluster cluster) { /* These bundles are added to all application container clusters, even if they haven't * declared 'model-evaluation' in services.xml, because there are many public API packages diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc index 992689a2189..945deec9f91 100644 --- a/config-model/src/main/resources/schema/containercluster.rnc +++ b/config-model/src/main/resources/schema/containercluster.rnc @@ -105,7 +105,16 @@ ZooKeeper = element zookeeper { } ModelEvaluation = element model-evaluation { - empty + element onnx { + element models { + element model { + attribute name { string } & + element intraop-threads { xsd:nonNegativeInteger }? & + element interop-threads { xsd:nonNegativeInteger }? & + element execution-mode { string "sequential" | string "parallel" }? + }* + }? + }? } Ssl = element ssl { diff --git a/config-model/src/test/cfg/application/onnx/services.xml b/config-model/src/test/cfg/application/onnx/services.xml index 8731558c6f7..1cae9250009 100644 --- a/config-model/src/test/cfg/application/onnx/services.xml +++ b/config-model/src/test/cfg/application/onnx/services.xml @@ -3,7 +3,19 @@ - + + + + + 2 + + + 400 + parallel + + + + diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java index b369560be74..27f1fc26b15 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java @@ -123,6 +123,10 @@ public class StatelessOnnxEvaluationTest { Tensor output = evaluator.bind("input1", input1).bind("input2", input2).evaluate(); assertEquals(6.0, output.sum().asDouble(), 1e-9); + OnnxModelsConfig.Model mulModel = onnxModelsConfig.model().get(0); + assertEquals(2, mulModel.stateless_intraop_threads()); + assertEquals(-1, mulModel.stateless_interop_threads()); + assertEquals("", mulModel.stateless_execution_mode()); } } -- cgit v1.2.3