summaryrefslogtreecommitdiffstats
path: root/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java')
-rw-r--r--config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java61
1 files changed, 57 insertions, 4 deletions
diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java
index 776e5c621bb..836bf5f72d0 100644
--- a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java
+++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java
@@ -120,10 +120,48 @@ public class NodeResources {
}
+ public record GpuResources(int count, double memoryGb) {
+
+ private static final GpuResources none = new GpuResources(0, 0);
+
+ public GpuResources {
+ if (count < 0) throw new IllegalArgumentException("GPU count cannot be negative, got " + count);
+ if (memoryGb < 0) throw new IllegalArgumentException("GPU memory cannot be negative, got " + memoryGb);
+ validate(memoryGb, "memory");
+ }
+
+ private double totalMemory() {
+ return count * memoryGb;
+ }
+
+ public boolean lessThan(GpuResources other) {
+ return totalMemory() < other.totalMemory();
+ }
+
+ public boolean isDefault() { return this.equals(getDefault()); }
+
+ public static GpuResources getDefault() { return none; }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ GpuResources that = (GpuResources) o;
+ return count == that.count && equal(this.memoryGb, that.memoryGb);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(count, memoryGb);
+ }
+
+ }
+
private final double vcpu;
private final double memoryGb;
private final double diskGb;
private final double bandwidthGbps;
+ private final GpuResources gpuResources;
private final DiskSpeed diskSpeed;
private final StorageType storageType;
private final Architecture architecture;
@@ -133,18 +171,23 @@ public class NodeResources {
}
public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed) {
- this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, StorageType.getDefault(), Architecture.getDefault());
+ this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, StorageType.getDefault(), Architecture.getDefault(), GpuResources.getDefault());
}
public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType) {
- this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, storageType, Architecture.getDefault());
+ this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, storageType, Architecture.getDefault(), GpuResources.getDefault());
}
public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType, Architecture architecture) {
+ this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, storageType, architecture, GpuResources.getDefault());
+ }
+
+ public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType, Architecture architecture, GpuResources gpuResources) {
this.vcpu = validate(vcpu, "vcpu");
this.memoryGb = validate(memoryGb, "memory");
this.diskGb = validate(diskGb, "disk");
this.bandwidthGbps = validate(bandwidthGbps, "bandwidth");
+ this.gpuResources = gpuResources;
this.diskSpeed = diskSpeed;
this.storageType = storageType;
this.architecture = architecture;
@@ -157,6 +200,7 @@ public class NodeResources {
public DiskSpeed diskSpeed() { return diskSpeed; }
public StorageType storageType() { return storageType; }
public Architecture architecture() { return architecture; }
+ public GpuResources gpuResources() { return gpuResources; }
/** Returns the standard cost of these resources, in dollars per hour */
public double cost() {
@@ -264,6 +308,7 @@ public class NodeResources {
if ( ! equal(this.memoryGb, other.memoryGb)) return false;
if ( ! equal(this.diskGb, other.diskGb)) return false;
if ( ! equal(this.bandwidthGbps, other.bandwidthGbps)) return false;
+ if ( ! this.gpuResources.equals(other.gpuResources)) return false;
if (this.diskSpeed != other.diskSpeed) return false;
if (this.storageType != other.storageType) return false;
if (this.architecture != other.architecture) return false;
@@ -305,6 +350,12 @@ public class NodeResources {
sb.append(", storage type: ").append(storageType);
}
sb.append(", architecture: ").append(architecture);
+ if ( !gpuResources.isDefault()) {
+ sb.append(", gpu count: ").append(gpuResources.count());
+ sb.append(", gpu memory: ");
+ appendDouble(sb, memoryGb);
+ sb.append(" Gb");
+ }
sb.append(']');
return sb.toString();
}
@@ -317,6 +368,7 @@ public class NodeResources {
if (this.memoryGb < other.memoryGb) return false;
if (this.diskGb < other.diskGb) return false;
if (this.bandwidthGbps < other.bandwidthGbps) return false;
+ if (this.gpuResources.lessThan(other.gpuResources)) return false;
// Why doesn't a fast disk satisfy a slow disk? Because if slow disk is explicitly specified
// (i.e not "any"), you should not randomly, sometimes get a faster disk as that means you may
@@ -364,6 +416,7 @@ public class NodeResources {
if ( ! equal(this.memoryGb, other.memoryGb)) return false;
if ( ! equal(this.diskGb, other.diskGb)) return false;
if ( ! equal(this.bandwidthGbps, other.bandwidthGbps)) return false;
+ if ( ! this.gpuResources.equals(other.gpuResources)) return false;
if ( ! this.diskSpeed.compatibleWith(other.diskSpeed)) return false;
if ( ! this.storageType.compatibleWith(other.storageType)) return false;
if ( ! this.architecture.compatibleWith(other.architecture)) return false;
@@ -396,7 +449,7 @@ public class NodeResources {
return this.isUnspecified() ? Optional.empty() : Optional.of(this);
}
- private boolean equal(double a, double b) {
+ private static boolean equal(double a, double b) {
return Math.abs(a - b) < 0.00000001;
}
@@ -421,7 +474,7 @@ public class NodeResources {
return new NodeResources(cpu, mem, dsk, 0.3, DiskSpeed.getDefault(), StorageType.getDefault(), Architecture.x86_64);
}
- private double validate(double value, String valueName) {
+ private static double validate(double value, String valueName) {
if (Double.isNaN(value)) throw new IllegalArgumentException(valueName + " cannot be NaN");
if (Double.isInfinite(value)) throw new IllegalArgumentException(valueName + " cannot be infinite");
return value;