diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/OnnxModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/OnnxModel.java | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 90a27d1f036..3295b2e93aa 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -18,13 +18,15 @@ import java.util.Set; * * @author lesters */ -public class OnnxModel extends DistributableResource { +public class OnnxModel extends DistributableResource implements Cloneable { + // Model information private OnnxModelInfo modelInfo = null; private final Map<String, String> inputMap = new HashMap<>(); private final Map<String, String> outputMap = new HashMap<>(); private final Set<String> initializers = new HashSet<>(); + // Runtime options private String statelessExecutionMode = null; private Integer statelessInterOpThreads = null; private Integer statelessIntraOpThreads = null; @@ -40,6 +42,15 @@ public class OnnxModel extends DistributableResource { } @Override + public OnnxModel clone() { + try { + return (OnnxModel) super.clone(); // Shallow clone is sufficient here + } catch (CloneNotSupportedException e) { + throw new RuntimeException("Clone not supported", e); + } + } + + @Override public void setUri(String uri) { throw new IllegalArgumentException("URI for ONNX models are not currently supported"); } @@ -148,26 +159,24 @@ public class OnnxModel extends DistributableResource { } } + public Optional<Integer> getStatelessIntraOpThreads() { + return Optional.ofNullable(statelessIntraOpThreads); + } + public void setGpuDevice(int deviceNumber, boolean required) { if (deviceNumber >= 0) { this.gpuDevice = new GpuDevice(deviceNumber, required); } } - public Optional<Integer> getStatelessIntraOpThreads() { - return Optional.ofNullable(statelessIntraOpThreads); - } - public Optional<GpuDevice> getGpuDevice() { return Optional.ofNullable(gpuDevice); } public record GpuDevice(int deviceNumber, boolean required) { - public GpuDevice { if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); } - } } |