diff options
author | Harald Musum <musum@yahooinc.com> | 2023-11-21 09:52:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-21 09:52:52 +0100 |
commit | 430c0f8c9e1ea5eaeae2b795cd4b7350091679ae (patch) | |
tree | 4e49701083ff1600df5775f481eb3057edbf88bf /config-model-api/src | |
parent | d998b2774ce916ce5a92f4879f3f47a23f1346a9 (diff) | |
parent | 9d28a47b003f5498bc59bfd10017dd55fc7ab6e0 (diff) |
Merge pull request #29388 from vespa-engine/hmusum/register-with-onnx-model-options
Register model with onnx model options
Diffstat (limited to 'config-model-api/src')
-rw-r--r-- | config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java | 5 | ||||
-rw-r--r-- | config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java | 49 |
2 files changed, 54 insertions, 0 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java index acb88070482..b98667457e4 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -10,6 +10,7 @@ import java.net.URI; /** * @author bjorncs */ +// TODO: Rename public interface OnnxModelCost { Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId); @@ -17,7 +18,9 @@ public interface OnnxModelCost { interface Calculator { long aggregatedModelCostInBytes(); void registerModel(ApplicationFile path); + void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions); void registerModel(URI uri); + void registerModel(URI uri, OnnxModelOptions onnxModelOptions); } static OnnxModelCost disabled() { return new DisabledOnnxModelCost(); } @@ -26,7 +29,9 @@ public interface OnnxModelCost { @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; } @Override public long aggregatedModelCostInBytes() {return 0;} @Override public void registerModel(ApplicationFile path) {} + @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} @Override public void registerModel(URI uri) {} + @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {} } } diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java new file mode 100644 index 00000000000..92817baae3f --- /dev/null +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java @@ -0,0 +1,49 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
import java.util.Optional;

/**
 * Onnx model options that are relevant when deciding if an Onnx model needs to be reloaded. If any of the
 * values in this class change, reload is needed.
 *
 * Instances are immutable; the {@code with...} methods return a modified copy.
 *
 * @param executionMode  the execution mode, or empty
 * @param interOpThreads the inter-op thread setting, or empty
 * @param intraOpThreads the intra-op thread setting, or empty
 * @param gpuDevice      the GPU device, or empty
 *
 * @author hmusum
 */
public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads,
                               Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) {

    /**
     * Creates options from plain values.
     *
     * A {@code null} execution mode or GPU device is stored as an empty {@link Optional}, which is the same
     * treatment {@link #withExecutionMode(String)} and {@link #withGpuDevice(GpuDevice)} give to {@code null}.
     */
    public OnnxModelOptions(String executionMode, int interOpThreads, int intraOpThreads, GpuDevice gpuDevice) {
        this(Optional.ofNullable(executionMode),
             Optional.of(interOpThreads),
             Optional.of(intraOpThreads),
             Optional.ofNullable(gpuDevice));
    }

    /** Returns options where every value is empty. */
    public static OnnxModelOptions empty() {
        return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
    }

    /** Returns a copy with the given execution mode; {@code null} makes it empty. */
    public OnnxModelOptions withExecutionMode(String executionMode) {
        return new OnnxModelOptions(Optional.ofNullable(executionMode), interOpThreads, intraOpThreads, gpuDevice);
    }

    /** Returns a copy with the given inter-op thread setting; {@code null} makes it empty. */
    public OnnxModelOptions withInterOpThreads(Integer interOpThreads) {
        return new OnnxModelOptions(executionMode, Optional.ofNullable(interOpThreads), intraOpThreads, gpuDevice);
    }

    /** Returns a copy with the given intra-op thread setting; {@code null} makes it empty. */
    public OnnxModelOptions withIntraOpThreads(Integer intraOpThreads) {
        return new OnnxModelOptions(executionMode, interOpThreads, Optional.ofNullable(intraOpThreads), gpuDevice);
    }

    /** Returns a copy with the given GPU device; {@code null} makes it empty. */
    public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) {
        return new OnnxModelOptions(executionMode, interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice));
    }

    /**
     * A GPU device, identified by its device number.
     *
     * @param deviceNumber the device number, never negative
     * @param required     the "required" flag for this device
     */
    public record GpuDevice(int deviceNumber, boolean required) {

        /**
         * Validates the device number.
         *
         * @throws IllegalArgumentException if {@code deviceNumber} is negative
         */
        public GpuDevice {
            if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
        }

        /** Creates a device with the given number and {@code required} set to {@code false}. */
        public GpuDevice(int deviceNumber) {
            this(deviceNumber, false);
        }

    }

}