diff options
author | Martin Polden <mpolden@mpolden.no> | 2023-02-07 16:04:57 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2023-02-08 09:52:53 +0100 |
commit | c47f27f1c3362b459e276c59ebcd09ab259b710e (patch) | |
tree | f68963b9b91558d9418da0cb88e5a16d67a52034 /config-model/src/main/java/com/yahoo/schema/OnnxModel.java | |
parent | d03bf2ed1b239f4998bdfd6580965bcd0a7d62a4 (diff) |
Allow fallback to CPU if nodes are provisioned without GPU
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/OnnxModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/OnnxModel.java | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 3d96849fa15..ae6f1fd96e4 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -24,7 +24,7 @@ public class OnnxModel extends DistributableResource { private String statelessExecutionMode = null; private Integer statelessInterOpThreads = null; private Integer statelessIntraOpThreads = null; - private Integer gpuDevice = null; + private GpuDevice gpuDevice = null; public OnnxModel(String name) { super(name); @@ -114,9 +114,9 @@ public class OnnxModel extends DistributableResource { } } - public void setGpuDevice(int deviceNumber) { + public void setGpuDevice(int deviceNumber, boolean required) { if (deviceNumber >= 0) { - this.gpuDevice = deviceNumber; + this.gpuDevice = new GpuDevice(deviceNumber, required); } } @@ -124,8 +124,16 @@ public class OnnxModel extends DistributableResource { return Optional.ofNullable(statelessIntraOpThreads); } - public Optional<Integer> getGpuDevice() { + public Optional<GpuDevice> getGpuDevice() { return Optional.ofNullable(gpuDevice); } + public record GpuDevice(int deviceNumber, boolean required) { + + public GpuDevice { + if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); + } + + } + } |