diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/OnnxModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/OnnxModel.java | 32 |
1 files changed, 13 insertions, 19 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index f3f09150c1d..9456baafd57 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.ml.OnnxModelInfo; @@ -27,10 +28,7 @@ public class OnnxModel extends DistributableResource implements Cloneable { private final Set<String> initializers = new HashSet<>(); // Runtime options - private String statelessExecutionMode = null; - private Integer statelessInterOpThreads = null; - private Integer statelessIntraOpThreads = null; - private GpuDevice gpuDevice = null; + private OnnxModelOptions onnxModelOptions = OnnxModelOptions.empty(); public OnnxModel(String name) { super(name); @@ -133,50 +131,46 @@ public class OnnxModel extends DistributableResource implements Cloneable { public void setStatelessExecutionMode(String executionMode) { if ("parallel".equalsIgnoreCase(executionMode)) { - this.statelessExecutionMode = "parallel"; + onnxModelOptions = onnxModelOptions.withExecutionMode("parallel"); } else if ("sequential".equalsIgnoreCase(executionMode)) { - this.statelessExecutionMode = "sequential"; + onnxModelOptions = onnxModelOptions.withExecutionMode("sequential"); } } public Optional<String> getStatelessExecutionMode() { - return Optional.ofNullable(statelessExecutionMode); + return onnxModelOptions.executionMode(); } public void setStatelessInterOpThreads(int interOpThreads) { if (interOpThreads >= 0) { - this.statelessInterOpThreads = interOpThreads; + onnxModelOptions = onnxModelOptions.withInterOpThreads(interOpThreads); } } public Optional<Integer> getStatelessInterOpThreads() { - return Optional.ofNullable(statelessInterOpThreads); + return onnxModelOptions.interOpThreads(); } public void setStatelessIntraOpThreads(int intraOpThreads) { if (intraOpThreads >= 0) { - this.statelessIntraOpThreads = intraOpThreads; + onnxModelOptions = onnxModelOptions.withIntraOpThreads(intraOpThreads); } } public Optional<Integer> getStatelessIntraOpThreads() { - return Optional.ofNullable(statelessIntraOpThreads); + return onnxModelOptions.intraOpThreads(); } public void setGpuDevice(int deviceNumber, boolean required) { if (deviceNumber >= 0) { - this.gpuDevice = new GpuDevice(deviceNumber, required); + onnxModelOptions = onnxModelOptions.withGpuDevice(new OnnxModelOptions.GpuDevice(deviceNumber, required)); } } - public Optional<GpuDevice> getGpuDevice() { - return Optional.ofNullable(gpuDevice); + public Optional<OnnxModelOptions.GpuDevice> getGpuDevice() { + return onnxModelOptions.gpuDevice(); } - public record GpuDevice(int deviceNumber, boolean required) { - public GpuDevice { - if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); - } - } + public OnnxModelOptions onnxModelOptions() { return onnxModelOptions; } } |