aboutsummaryrefslogtreecommitdiffstats
path: root/config-model-api
diff options
context:
space:
mode:
authorHarald Musum <musum@yahooinc.com>2023-11-21 09:52:52 +0100
committerGitHub <noreply@github.com>2023-11-21 09:52:52 +0100
commit430c0f8c9e1ea5eaeae2b795cd4b7350091679ae (patch)
tree4e49701083ff1600df5775f481eb3057edbf88bf /config-model-api
parentd998b2774ce916ce5a92f4879f3f47a23f1346a9 (diff)
parent9d28a47b003f5498bc59bfd10017dd55fc7ab6e0 (diff)
Merge pull request #29388 from vespa-engine/hmusum/register-with-onnx-model-options
Register model with onnx model options
Diffstat (limited to 'config-model-api')
-rw-r--r--config-model-api/abi-spec.json53
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java5
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java49
3 files changed, 105 insertions, 2 deletions
diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json
index d9c68c89189..78b32d8af7b 100644
--- a/config-model-api/abi-spec.json
+++ b/config-model-api/abi-spec.json
@@ -1453,7 +1453,9 @@
"methods" : [
"public abstract long aggregatedModelCostInBytes()",
"public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile)",
- "public abstract void registerModel(java.net.URI)"
+ "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)",
+ "public abstract void registerModel(java.net.URI)",
+ "public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)"
],
"fields" : [ ]
},
@@ -1471,7 +1473,9 @@
"public com.yahoo.config.model.api.OnnxModelCost$Calculator newCalculator(com.yahoo.config.application.api.ApplicationPackage, com.yahoo.config.provision.ApplicationId)",
"public long aggregatedModelCostInBytes()",
"public void registerModel(com.yahoo.config.application.api.ApplicationFile)",
- "public void registerModel(java.net.URI)"
+ "public void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)",
+ "public void registerModel(java.net.URI)",
+ "public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)"
],
"fields" : [ ]
},
@@ -1489,6 +1493,51 @@
],
"fields" : [ ]
},
+ "com.yahoo.config.model.api.OnnxModelOptions$GpuDevice" : {
+ "superClass" : "java.lang.Record",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final",
+ "record"
+ ],
+ "methods" : [
+ "public void <init>(int, boolean)",
+ "public void <init>(int)",
+ "public final java.lang.String toString()",
+ "public final int hashCode()",
+ "public final boolean equals(java.lang.Object)",
+ "public int deviceNumber()",
+ "public boolean required()"
+ ],
+ "fields" : [ ]
+ },
+ "com.yahoo.config.model.api.OnnxModelOptions" : {
+ "superClass" : "java.lang.Record",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final",
+ "record"
+ ],
+ "methods" : [
+ "public void <init>(java.lang.String, int, int, com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)",
+ "public void <init>(java.util.Optional, java.util.Optional, java.util.Optional, java.util.Optional)",
+ "public static com.yahoo.config.model.api.OnnxModelOptions empty()",
+ "public com.yahoo.config.model.api.OnnxModelOptions withExecutionMode(java.lang.String)",
+ "public com.yahoo.config.model.api.OnnxModelOptions withInterOpThreads(java.lang.Integer)",
+ "public com.yahoo.config.model.api.OnnxModelOptions withIntraOpThreads(java.lang.Integer)",
+ "public com.yahoo.config.model.api.OnnxModelOptions withGpuDevice(com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)",
+ "public final java.lang.String toString()",
+ "public final int hashCode()",
+ "public final boolean equals(java.lang.Object)",
+ "public java.util.Optional executionMode()",
+ "public java.util.Optional interOpThreads()",
+ "public java.util.Optional intraOpThreads()",
+ "public java.util.Optional gpuDevice()"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.config.model.api.PortInfo" : {
"superClass" : "java.lang.Object",
"interfaces" : [ ],
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
index acb88070482..b98667457e4 100644
--- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
@@ -10,6 +10,7 @@ import java.net.URI;
/**
* @author bjorncs
*/
+// TODO: Rename
public interface OnnxModelCost {
Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId);
@@ -17,7 +18,9 @@ public interface OnnxModelCost {
interface Calculator {
long aggregatedModelCostInBytes();
void registerModel(ApplicationFile path);
+ void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions);
void registerModel(URI uri);
+ void registerModel(URI uri, OnnxModelOptions onnxModelOptions);
}
static OnnxModelCost disabled() { return new DisabledOnnxModelCost(); }
@@ -26,7 +29,9 @@ public interface OnnxModelCost {
@Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; }
@Override public long aggregatedModelCostInBytes() {return 0;}
@Override public void registerModel(ApplicationFile path) {}
+ @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {}
@Override public void registerModel(URI uri) {}
+ @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {}
}
}
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java
new file mode 100644
index 00000000000..92817baae3f
--- /dev/null
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java
@@ -0,0 +1,49 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.config.model.api;
+
+import java.util.Optional;
+
+/**
+ * Onnx model options that are relevant when deciding if an Onnx model needs to be reloaded. If any of the
+ * values in this class change, reload is needed.
+ *
+ * @author hmusum
+ */
+public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads,
+ Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) {
+
+ public OnnxModelOptions(String executionMode, int interOpThreads, int intraOpThreads, GpuDevice gpuDevice) {
+ this(Optional.of(executionMode), Optional.of(interOpThreads), Optional.of(intraOpThreads), Optional.of(gpuDevice));
+ }
+
+ public static OnnxModelOptions empty() {
+ return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
+ }
+
+ public OnnxModelOptions withExecutionMode(String executionMode) {
+ return new OnnxModelOptions(Optional.ofNullable(executionMode), interOpThreads, intraOpThreads, gpuDevice);
+ }
+
+ public OnnxModelOptions withInterOpThreads(Integer interOpThreads) {
+ return new OnnxModelOptions(executionMode, Optional.ofNullable(interOpThreads), intraOpThreads, gpuDevice);
+ }
+
+ public OnnxModelOptions withIntraOpThreads(Integer intraOpThreads) {
+ return new OnnxModelOptions(executionMode, interOpThreads, Optional.ofNullable(intraOpThreads), gpuDevice);
+ }
+
+ public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) {
+ return new OnnxModelOptions(executionMode, interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice));
+ }
+
+ public record GpuDevice(int deviceNumber, boolean required) {
+ public GpuDevice {
+ if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
+ }
+
+ public GpuDevice(int deviceNumber) {
+ this(deviceNumber, false);
+ }
+ }
+
+}