diff options
author | Harald Musum <musum@yahooinc.com> | 2023-11-20 15:17:28 +0100 |
---|---|---|
committer | Harald Musum <musum@yahooinc.com> | 2023-11-20 15:17:28 +0100 |
commit | 9d28a47b003f5498bc59bfd10017dd55fc7ab6e0 (patch) | |
tree | add8988cb32377e8395c36d2fb46cc9259b5dd63 /config-model-api | |
parent | e8c0a04b67b632ea3f98327d8f39cd0293ad8581 (diff) |
Register model with onnx model options
Diffstat (limited to 'config-model-api')
3 files changed, 105 insertions, 2 deletions
diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json index bf2f2cfac44..cab011e47be 100644 --- a/config-model-api/abi-spec.json +++ b/config-model-api/abi-spec.json @@ -1455,7 +1455,9 @@ "methods" : [ "public abstract long aggregatedModelCostInBytes()", "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile)", - "public abstract void registerModel(java.net.URI)" + "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)", + "public abstract void registerModel(java.net.URI)", + "public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)" ], "fields" : [ ] }, @@ -1473,7 +1475,9 @@ "public com.yahoo.config.model.api.OnnxModelCost$Calculator newCalculator(com.yahoo.config.application.api.ApplicationPackage, com.yahoo.config.provision.ApplicationId)", "public long aggregatedModelCostInBytes()", "public void registerModel(com.yahoo.config.application.api.ApplicationFile)", - "public void registerModel(java.net.URI)" + "public void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)", + "public void registerModel(java.net.URI)", + "public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)" ], "fields" : [ ] }, @@ -1491,6 +1495,51 @@ ], "fields" : [ ] }, + "com.yahoo.config.model.api.OnnxModelOptions$GpuDevice" : { + "superClass" : "java.lang.Record", + "interfaces" : [ ], + "attributes" : [ + "public", + "final", + "record" + ], + "methods" : [ + "public void <init>(int, boolean)", + "public void <init>(int)", + "public final java.lang.String toString()", + "public final int hashCode()", + "public final boolean equals(java.lang.Object)", + "public int deviceNumber()", + "public boolean required()" + ], + "fields" : [ ] + }, + "com.yahoo.config.model.api.OnnxModelOptions" : { + "superClass" : "java.lang.Record", + "interfaces" : [ ], + "attributes" : [ + "public", + "final", + "record" + ], + "methods" : [ + "public void <init>(java.lang.String, int, int, com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)", + "public void <init>(java.util.Optional, java.util.Optional, java.util.Optional, java.util.Optional)", + "public static com.yahoo.config.model.api.OnnxModelOptions empty()", + "public com.yahoo.config.model.api.OnnxModelOptions withExecutionMode(java.lang.String)", + "public com.yahoo.config.model.api.OnnxModelOptions withInterOpThreads(java.lang.Integer)", + "public com.yahoo.config.model.api.OnnxModelOptions withIntraOpThreads(java.lang.Integer)", + "public com.yahoo.config.model.api.OnnxModelOptions withGpuDevice(com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)", + "public final java.lang.String toString()", + "public final int hashCode()", + "public final boolean equals(java.lang.Object)", + "public java.util.Optional executionMode()", + "public java.util.Optional interOpThreads()", + "public java.util.Optional intraOpThreads()", + "public java.util.Optional gpuDevice()" + ], + "fields" : [ ] + }, "com.yahoo.config.model.api.PortInfo" : { "superClass" : "java.lang.Object", "interfaces" : [ ], diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java index acb88070482..b98667457e4 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -10,6 +10,7 @@ import java.net.URI; /** * @author bjorncs */ +// TODO: Rename public interface OnnxModelCost { Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId); @@ -17,7 +18,9 @@ public interface OnnxModelCost { interface Calculator { long aggregatedModelCostInBytes(); void registerModel(ApplicationFile path); + void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions); void registerModel(URI uri); + void registerModel(URI uri, OnnxModelOptions onnxModelOptions); } static OnnxModelCost disabled() { return new DisabledOnnxModelCost(); } @@ -26,7 +29,9 @@ public interface OnnxModelCost { @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; } @Override public long aggregatedModelCostInBytes() {return 0;} @Override public void registerModel(ApplicationFile path) {} + @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} @Override public void registerModel(URI uri) {} + @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {} } } diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java new file mode 100644 index 00000000000..92817baae3f --- /dev/null +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java @@ -0,0 +1,49 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.config.model.api; + +import java.util.Optional; + +/** + * Onnx model options that are relevant when deciding if an Onnx model needs to be reloaded. If any of the + * values in this class change, reload is needed. + * + * @author hmusum + */ +public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads, + Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) { + + public OnnxModelOptions(String executionMode, int interOpThreads, int intraOpThreads, GpuDevice gpuDevice) { + this(Optional.of(executionMode), Optional.of(interOpThreads), Optional.of(intraOpThreads), Optional.of(gpuDevice)); + } + + public static OnnxModelOptions empty() { + return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + } + + public OnnxModelOptions withExecutionMode(String executionMode) { + return new OnnxModelOptions(Optional.ofNullable(executionMode), interOpThreads, intraOpThreads, gpuDevice); + } + + public OnnxModelOptions withInterOpThreads(Integer interOpThreads) { + return new OnnxModelOptions(executionMode, Optional.ofNullable(interOpThreads), intraOpThreads, gpuDevice); + } + + public OnnxModelOptions withIntraOpThreads(Integer intraOpThreads) { + return new OnnxModelOptions(executionMode, interOpThreads, Optional.ofNullable(intraOpThreads), gpuDevice); + } + + public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) { + return new OnnxModelOptions(executionMode, interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice)); + } + + public record GpuDevice(int deviceNumber, boolean required) { + public GpuDevice { + if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); + } + + public GpuDevice(int deviceNumber) { + this(deviceNumber, false); + } + } + +} |