diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-03-22 23:34:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-22 23:34:53 +0100 |
commit | a9a6d2275c49f5690791cbb50648589ea800a146 (patch) | |
tree | 79ebd0a3bf9f2944b3bb2c26153e3df4021461ae /config-model/src/main | |
parent | e260f413fe355b0ddb39a86a77f49accc5e738b6 (diff) | |
parent | a06c1cf91899f5da327c408c61c798ffddfd32da (diff) |
Merge pull request #26537 from vespa-engine/arnej/add-stateless-settings-in-schemav8.144.19
Arnej/add stateless settings in schema
Diffstat (limited to 'config-model/src/main')
4 files changed, 18 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/Derived.java b/config-model/src/main/java/com/yahoo/schema/derived/Derived.java index 9943a02a2f2..e8b12f22b20 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/Derived.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/Derived.java @@ -95,7 +95,7 @@ public abstract class Derived implements Exportable { * @param toDirectory the directory to export to, or null * */ - public final void export(String toDirectory) throws IOException { + public void export(String toDirectory) throws IOException { Writer writer = null; try { String fileName = getDerivedName() + ".cfg"; diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java index 4196af18fb6..e3c697e3262 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java @@ -16,10 +16,13 @@ import java.util.logging.Logger; * * @author bratseth */ -public class FileDistributedOnnxModels { +public class FileDistributedOnnxModels extends Derived implements OnnxModelsConfig.Producer { private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName()); + @Override + public String getDerivedName() { return "onnx-models"; } + private final Map<String, OnnxModel> models; public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) { diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/schema/derived/RankProfileList.java index c254385a96e..a50ddd4aeea 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/RankProfileList.java @@ -193,6 +193,12 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ public String getDerivedName() { return "rank-profiles"; } @Override + public void export(String toDirectory) throws java.io.IOException { + super.export(toDirectory); + onnxModels.export(toDirectory); + } + + @Override public void getConfig(RankProfilesConfig.Builder builder) { for (RawRankProfile rank : rankProfiles.values() ) { rank.getConfig(builder); diff --git a/config-model/src/main/javacc/SchemaParser.jj b/config-model/src/main/javacc/SchemaParser.jj index fa9d34139ea..61e8574bc87 100644 --- a/config-model/src/main/javacc/SchemaParser.jj +++ b/config-model/src/main/javacc/SchemaParser.jj @@ -186,6 +186,9 @@ TOKEN : | < SUFFIX: "suffix" > | < CONSTANT: "constant"> | < ONNXMODEL: "onnx-model"> +| < INTRAOPTHREADS: "intraop-threads"> +| < INTEROPTHREADS: "interop-threads"> +| < GPUDEVICE: "gpu-device"> | < MODEL: "model" > | < MUTATE: "mutate" > | < QUERY: "query" > @@ -1594,11 +1597,15 @@ OnnxModel onnxModel() : void onnxModelItem(OnnxModel onnxModel) : { String path = null; + int num; } { ( (path = fileItem()) { onnxModel.setFileName(path); } | (path = uriItem()) { onnxModel.setUri(path); } | + <GPUDEVICE> <COLON> num = integer() { onnxModel.setGpuDevice(num, false); } | + <INTRAOPTHREADS> <COLON> num = integer() { onnxModel.setStatelessIntraOpThreads(num); } | + <INTEROPTHREADS> <COLON> num = integer() { onnxModel.setStatelessInterOpThreads(num); } | (<ONNX_INPUT_SL>) { String name = token.image.substring(5, token.image.lastIndexOf(":")).trim(); if (name.startsWith("\"")) { name = name.substring(1, name.length() - 1); } |