summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/OnnxModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java23
1 files changed, 16 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 90a27d1f036..3295b2e93aa 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -18,13 +18,15 @@ import java.util.Set;
*
* @author lesters
*/
-public class OnnxModel extends DistributableResource {
+public class OnnxModel extends DistributableResource implements Cloneable {
+ // Model information
private OnnxModelInfo modelInfo = null;
private final Map<String, String> inputMap = new HashMap<>();
private final Map<String, String> outputMap = new HashMap<>();
private final Set<String> initializers = new HashSet<>();
+ // Runtime options
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
private Integer statelessIntraOpThreads = null;
@@ -40,6 +42,15 @@ public class OnnxModel extends DistributableResource {
}
@Override
+ public OnnxModel clone() {
+ try {
+ return (OnnxModel) super.clone(); // Shallow clone is sufficient here
+ } catch (CloneNotSupportedException e) {
+ throw new RuntimeException("Clone not supported", e);
+ }
+ }
+
+ @Override
public void setUri(String uri) {
throw new IllegalArgumentException("URI for ONNX models are not currently supported");
}
@@ -148,26 +159,24 @@ public class OnnxModel extends DistributableResource {
}
}
+ public Optional<Integer> getStatelessIntraOpThreads() {
+ return Optional.ofNullable(statelessIntraOpThreads);
+ }
+
public void setGpuDevice(int deviceNumber, boolean required) {
if (deviceNumber >= 0) {
this.gpuDevice = new GpuDevice(deviceNumber, required);
}
}
- public Optional<Integer> getStatelessIntraOpThreads() {
- return Optional.ofNullable(statelessIntraOpThreads);
- }
-
public Optional<GpuDevice> getGpuDevice() {
return Optional.ofNullable(gpuDevice);
}
public record GpuDevice(int deviceNumber, boolean required) {
-
public GpuDevice {
if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
}
-
}
}