aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2023-08-30 10:15:39 +0200
committerLester Solbakken <lesters@oath.com>2023-08-30 10:15:39 +0200
commitcefce7443669e8dcd03c1d4905e875bef2fe83cb (patch)
tree14ff461f0f00bc9008e0e065866c585bb01306c0 /config-model
parent48a3511b28731a4b45089bdd4808167c58824b00 (diff)
Shallow clone is sufficient for onnxmodel
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java18
1 files changed, 7 insertions, 11 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index fbec97c797d..3295b2e93aa 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -20,11 +20,13 @@ import java.util.Set;
*/
public class OnnxModel extends DistributableResource implements Cloneable {
+ // Model information
private OnnxModelInfo modelInfo = null;
private final Map<String, String> inputMap = new HashMap<>();
private final Map<String, String> outputMap = new HashMap<>();
private final Set<String> initializers = new HashSet<>();
+ // Runtime options
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
private Integer statelessIntraOpThreads = null;
@@ -42,11 +44,7 @@ public class OnnxModel extends DistributableResource implements Cloneable {
@Override
public OnnxModel clone() {
try {
- OnnxModel clone = (OnnxModel) super.clone();
- clone.inputMap.putAll(inputMap);
- clone.outputMap.putAll(outputMap);
- clone.initializers.addAll(initializers);
- return clone;
+ return (OnnxModel) super.clone(); // Shallow clone is sufficient here
} catch (CloneNotSupportedException e) {
throw new RuntimeException("Clone not supported", e);
}
@@ -161,26 +159,24 @@ public class OnnxModel extends DistributableResource implements Cloneable {
}
}
+ public Optional<Integer> getStatelessIntraOpThreads() {
+ return Optional.ofNullable(statelessIntraOpThreads);
+ }
+
public void setGpuDevice(int deviceNumber, boolean required) {
if (deviceNumber >= 0) {
this.gpuDevice = new GpuDevice(deviceNumber, required);
}
}
- public Optional<Integer> getStatelessIntraOpThreads() {
- return Optional.ofNullable(statelessIntraOpThreads);
- }
-
public Optional<GpuDevice> getGpuDevice() {
return Optional.ofNullable(gpuDevice);
}
public record GpuDevice(int deviceNumber, boolean required) {
-
public GpuDevice {
if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
}
-
}
}