aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2023-02-07 16:04:57 +0100
committerMartin Polden <mpolden@mpolden.no>2023-02-08 09:52:53 +0100
commitc47f27f1c3362b459e276c59ebcd09ab259b710e (patch)
treef68963b9b91558d9418da0cb88e5a16d67a52034 /config-model/src/main/java/com/yahoo/schema/OnnxModel.java
parentd03bf2ed1b239f4998bdfd6580965bcd0a7d62a4 (diff)
Allow fallback to CPU if nodes are provisioned without GPU
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/OnnxModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java16
1 files changed, 12 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 3d96849fa15..ae6f1fd96e4 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -24,7 +24,7 @@ public class OnnxModel extends DistributableResource {
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
private Integer statelessIntraOpThreads = null;
- private Integer gpuDevice = null;
+ private GpuDevice gpuDevice = null;
public OnnxModel(String name) {
super(name);
@@ -114,9 +114,9 @@ public class OnnxModel extends DistributableResource {
}
}
- public void setGpuDevice(int deviceNumber) {
+ public void setGpuDevice(int deviceNumber, boolean required) {
if (deviceNumber >= 0) {
- this.gpuDevice = deviceNumber;
+ this.gpuDevice = new GpuDevice(deviceNumber, required);
}
}
@@ -124,8 +124,16 @@ public class OnnxModel extends DistributableResource {
return Optional.ofNullable(statelessIntraOpThreads);
}
- public Optional<Integer> getGpuDevice() {
+ public Optional<GpuDevice> getGpuDevice() {
return Optional.ofNullable(gpuDevice);
}
+ public record GpuDevice(int deviceNumber, boolean required) {
+
+ public GpuDevice {
+ if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
+ }
+
+ }
+
}