summaryrefslogtreecommitdiffstats
path: root/config-provisioning
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2022-11-16 15:36:09 +0100
committerMartin Polden <mpolden@mpolden.no>2022-11-17 11:57:56 +0100
commitde749a8872fd1d96b7ae7b194b80764abe2768ed (patch)
tree104addc41434e6ea169faf5d754bfd1dfba6212b /config-provisioning
parentd991ae45a85bf5520a5043ccc07d4d21424cba97 (diff)
Support GPU in node specification
Diffstat (limited to 'config-provisioning')
-rw-r--r--config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java3
-rw-r--r--config-provisioning/src/main/java/com/yahoo/config/provision/Flavor.java3
-rw-r--r--config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java61
-rw-r--r--config-provisioning/src/main/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializer.java88
-rw-r--r--config-provisioning/src/main/resources/configdefinitions/config.provisioning.flavors.def4
-rw-r--r--config-provisioning/src/test/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializerTest.java3
6 files changed, 116 insertions, 46 deletions
diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java b/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java
index 70e88418fb7..dd8d181e1df 100644
--- a/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java
+++ b/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java
@@ -3,6 +3,7 @@ package com.yahoo.config.provision;
import java.util.Objects;
import java.util.Optional;
+import java.util.stream.Stream;
/**
* A capacity request.
@@ -25,6 +26,8 @@ public final class Capacity {
if (max.smallerThan(min))
throw new IllegalArgumentException("The max capacity must be larger than the min capacity, but got min " +
min + " and max " + max);
+ if (!min.equals(max) && Stream.of(min, max).anyMatch(cr -> !cr.nodeResources().gpuResources().isDefault()))
+ throw new IllegalArgumentException("Capacity range does not allow GPU, got min " + min + " and max " + max);
this.min = min;
this.max = max;
this.required = required;
diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/Flavor.java b/config-provisioning/src/main/java/com/yahoo/config/provision/Flavor.java
index 1ae4974a4c2..4539aec58eb 100644
--- a/config-provisioning/src/main/java/com/yahoo/config/provision/Flavor.java
+++ b/config-provisioning/src/main/java/com/yahoo/config/provision/Flavor.java
@@ -38,7 +38,8 @@ public class Flavor {
flavorConfig.bandwidth() / 1000,
flavorConfig.fastDisk() ? NodeResources.DiskSpeed.fast : NodeResources.DiskSpeed.slow,
flavorConfig.remoteStorage() ? NodeResources.StorageType.remote : NodeResources.StorageType.local,
- NodeResources.Architecture.valueOf(flavorConfig.architecture())),
+ NodeResources.Architecture.valueOf(flavorConfig.architecture()),
+ new NodeResources.GpuResources(flavorConfig.gpuCount(), flavorConfig.gpuMemoryGb())),
Optional.empty(),
Type.valueOf(flavorConfig.environment()),
true,
diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java
index 25771f1906b..e6bd7c70a82 100644
--- a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java
+++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java
@@ -120,10 +120,48 @@ public class NodeResources {
}
+ /**
+ * The GPU resources of a node: the number of GPUs and the memory of each GPU, in Gb.
+ *
+ * Two instances are equal when they have the same count and their memory is equal
+ * according to the tolerance of equal(double, double) in the enclosing class.
+ */
+ public record GpuResources(int count, double memoryGb) {
+
+ private static final GpuResources none = new GpuResources(0, 0);
+
+ /**
+ * Validates the components.
+ *
+ * @throws IllegalArgumentException if count or memoryGb is negative, or memoryGb is NaN or infinite
+ */
+ public GpuResources {
+ if (count < 0) throw new IllegalArgumentException("GPU count cannot be negative, got " + count);
+ if (memoryGb < 0) throw new IllegalArgumentException("GPU memory cannot be negative, got " + memoryGb);
+ validate(memoryGb, "memory");
+ }
+
+ /** Returns the memory summed over all GPUs: count * memoryGb */
+ private double totalMemory() {
+ return count * memoryGb;
+ }
+
+ /** Returns whether the total GPU memory of this is strictly less than that of the given resources */
+ public boolean lessThan(GpuResources other) {
+ return totalMemory() < other.totalMemory();
+ }
+
+ /** Returns whether this equals the default, which is no GPU */
+ public boolean isDefault() { return this.equals(getDefault()); }
+
+ /** Returns the default GPU resources: zero GPUs with zero memory */
+ public static GpuResources getDefault() { return none; }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ GpuResources that = (GpuResources) o;
+ return count == that.count && equal(this.memoryGb, that.memoryGb);
+ }
+
+ @Override
+ public int hashCode() {
+ // Hash on count only: equals() accepts memory values which differ by less than its tolerance
+ // (including 0.0 vs -0.0), so hashing memoryGb could give equal instances different hash codes
+ return Integer.hashCode(count);
+ }
+
+ }
+
private final double vcpu;
private final double memoryGb;
private final double diskGb;
private final double bandwidthGbps;
+ private final GpuResources gpuResources;
private final DiskSpeed diskSpeed;
private final StorageType storageType;
private final Architecture architecture;
@@ -133,18 +171,23 @@ public class NodeResources {
}
public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed) {
- this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, StorageType.getDefault(), Architecture.getDefault());
+ this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, StorageType.getDefault(), Architecture.getDefault(), GpuResources.getDefault());
}
public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType) {
- this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, storageType, Architecture.getDefault());
+ this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, storageType, Architecture.getDefault(), GpuResources.getDefault());
}
public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType, Architecture architecture) {
+ this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, storageType, architecture, GpuResources.getDefault());
+ }
+
+ public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType, Architecture architecture, GpuResources gpuResources) {
this.vcpu = validate(vcpu, "vcpu");
this.memoryGb = validate(memoryGb, "memory");
this.diskGb = validate(diskGb, "disk");
this.bandwidthGbps = validate(bandwidthGbps, "bandwidth");
+ this.gpuResources = gpuResources;
this.diskSpeed = diskSpeed;
this.storageType = storageType;
this.architecture = architecture;
@@ -157,6 +200,7 @@ public class NodeResources {
public DiskSpeed diskSpeed() { return diskSpeed; }
public StorageType storageType() { return storageType; }
public Architecture architecture() { return architecture; }
+ public GpuResources gpuResources() { return gpuResources; }
/** Returns the standard cost of these resources, in dollars per hour */
public double cost() {
@@ -265,6 +309,7 @@ public class NodeResources {
if ( ! equal(this.memoryGb, other.memoryGb)) return false;
if ( ! equal(this.diskGb, other.diskGb)) return false;
if ( ! equal(this.bandwidthGbps, other.bandwidthGbps)) return false;
+ if ( ! this.gpuResources.equals(other.gpuResources)) return false;
if (this.diskSpeed != other.diskSpeed) return false;
if (this.storageType != other.storageType) return false;
if (this.architecture != other.architecture) return false;
@@ -306,6 +351,12 @@ public class NodeResources {
sb.append(", storage type: ").append(storageType);
}
sb.append(", architecture: ").append(architecture);
+ if ( !gpuResources.isDefault()) {
+ sb.append(", gpu count: ").append(gpuResources.count());
+ sb.append(", gpu memory: ");
+ // Print the memory of the GPU resources, not the node's main memoryGb field
+ appendDouble(sb, gpuResources.memoryGb());
+ sb.append(" Gb");
+ }
sb.append(']');
return sb.toString();
}
@@ -318,6 +369,7 @@ public class NodeResources {
if (this.memoryGb < other.memoryGb) return false;
if (this.diskGb < other.diskGb) return false;
if (this.bandwidthGbps < other.bandwidthGbps) return false;
+ if (this.gpuResources.lessThan(other.gpuResources)) return false;
// Why doesn't a fast disk satisfy a slow disk? Because if slow disk is explicitly specified
// (i.e not "any"), you should not randomly, sometimes get a faster disk as that means you may
@@ -339,6 +391,7 @@ public class NodeResources {
if ( ! equal(this.memoryGb, other.memoryGb)) return false;
if ( ! equal(this.diskGb, other.diskGb)) return false;
if ( ! equal(this.bandwidthGbps, other.bandwidthGbps)) return false;
+ if ( ! this.gpuResources.equals(other.gpuResources)) return false;
if ( ! this.diskSpeed.compatibleWith(other.diskSpeed)) return false;
if ( ! this.storageType.compatibleWith(other.storageType)) return false;
if ( ! this.architecture.compatibleWith(other.architecture)) return false;
@@ -371,7 +424,7 @@ public class NodeResources {
return this.isUnspecified() ? Optional.empty() : Optional.of(this);
}
- private boolean equal(double a, double b) {
+ private static boolean equal(double a, double b) {
return Math.abs(a - b) < 0.00000001;
}
@@ -396,7 +449,7 @@ public class NodeResources {
return new NodeResources(cpu, mem, dsk, 0.3, DiskSpeed.getDefault(), StorageType.getDefault(), Architecture.x86_64);
}
- private double validate(double value, String valueName) {
+ private static double validate(double value, String valueName) {
if (Double.isNaN(value)) throw new IllegalArgumentException(valueName + " cannot be NaN");
if (Double.isInfinite(value)) throw new IllegalArgumentException(valueName + " cannot be infinite");
return value;
diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializer.java b/config-provisioning/src/main/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializer.java
index 5db7303f4bf..f539bc19c49 100644
--- a/config-provisioning/src/main/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializer.java
+++ b/config-provisioning/src/main/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializer.java
@@ -13,9 +13,7 @@ import com.yahoo.slime.Slime;
import com.yahoo.slime.SlimeUtils;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.LinkedHashSet;
-import java.util.List;
import java.util.Optional;
import java.util.Set;
@@ -51,6 +49,8 @@ public class AllocatedHostsSerializer {
private static final String diskSpeedKey = "diskSpeed";
private static final String storageTypeKey = "storageType";
private static final String architectureKey = "architecture";
+ private static final String gpuCountKey = "gpuCount";
+ private static final String gpuMemoryKey = "gpuMemory";
/** Wanted version */
private static final String hostSpecVespaVersionKey = "vespaVersion";
@@ -97,6 +97,10 @@ public class AllocatedHostsSerializer {
resourcesObject.setString(diskSpeedKey, diskSpeedToString(resources.diskSpeed()));
resourcesObject.setString(storageTypeKey, storageTypeToString(resources.storageType()));
resourcesObject.setString(architectureKey, architectureToString(resources.architecture()));
+ if (!resources.gpuResources().isDefault()) {
+ resourcesObject.setLong(gpuCountKey, resources.gpuResources().count());
+ resourcesObject.setDouble(gpuMemoryKey, resources.gpuResources().memoryGb());
+ }
}
public static AllocatedHosts fromJson(byte[] json) {
@@ -113,7 +117,6 @@ public class AllocatedHostsSerializer {
}
private static HostSpec hostFromSlime(Inspector object) {
-
if (object.field(hostSpecMembershipKey).valid()) { // Hosted
return new HostSpec(object.field(hostSpecHostNameKey).asString(),
nodeResourcesFromSlime(object.field(realResourcesKey)),
@@ -137,7 +140,15 @@ public class AllocatedHostsSerializer {
resources.field(bandwidthKey).asDouble(),
diskSpeedFromSlime(resources.field(diskSpeedKey)),
storageTypeFromSlime(resources.field(storageTypeKey)),
- architectureFromSlime(resources.field(architectureKey)));
+ architectureFromSlime(resources.field(architectureKey)),
+ gpuResourcesFromSlime(resources));
+ }
+
+ /**
+ * Reads GPU resources from the given resources object.
+ *
+ * Returns the default GPU resources when either the count or the memory field is not valid;
+ * the serializer omits both fields for default GPU resources.
+ *
+ * @throws ArithmeticException if the stored GPU count does not fit in an int
+ */
+ private static NodeResources.GpuResources gpuResourcesFromSlime(Inspector resources) {
+ Inspector gpuCountField = resources.field(gpuCountKey);
+ Inspector gpuMemoryField = resources.field(gpuMemoryKey);
+ if (!gpuCountField.valid() || !gpuMemoryField.valid()) return NodeResources.GpuResources.getDefault();
+ // toIntExact fails loudly on a count outside the int range instead of silently wrapping it
+ return new NodeResources.GpuResources(Math.toIntExact(gpuCountField.asLong()), gpuMemoryField.asDouble());
}
private static NodeResources optionalNodeResourcesFromSlime(Inspector resources) {
@@ -146,58 +157,55 @@ public class AllocatedHostsSerializer {
}
private static NodeResources.DiskSpeed diskSpeedFromSlime(Inspector diskSpeed) {
- switch (diskSpeed.asString()) {
- case "fast" : return NodeResources.DiskSpeed.fast;
- case "slow" : return NodeResources.DiskSpeed.slow;
- case "any" : return NodeResources.DiskSpeed.any;
- default: throw new IllegalStateException("Illegal disk-speed value '" + diskSpeed.asString() + "'");
- }
+ return switch (diskSpeed.asString()) {
+ case "fast" -> NodeResources.DiskSpeed.fast;
+ case "slow" -> NodeResources.DiskSpeed.slow;
+ case "any" -> NodeResources.DiskSpeed.any;
+ default -> throw new IllegalStateException("Illegal disk-speed value '" + diskSpeed.asString() + "'");
+ };
}
private static String diskSpeedToString(NodeResources.DiskSpeed diskSpeed) {
- switch (diskSpeed) {
- case fast : return "fast";
- case slow : return "slow";
- case any : return "any";
- default: throw new IllegalStateException("Illegal disk-speed value '" + diskSpeed + "'");
- }
+ return switch (diskSpeed) {
+ case fast -> "fast";
+ case slow -> "slow";
+ case any -> "any";
+ };
}
private static NodeResources.StorageType storageTypeFromSlime(Inspector storageType) {
- switch (storageType.asString()) {
- case "remote" : return NodeResources.StorageType.remote;
- case "local" : return NodeResources.StorageType.local;
- case "any" : return NodeResources.StorageType.any;
- default: throw new IllegalStateException("Illegal storage-type value '" + storageType.asString() + "'");
- }
+ return switch (storageType.asString()) {
+ case "remote" -> NodeResources.StorageType.remote;
+ case "local" -> NodeResources.StorageType.local;
+ case "any" -> NodeResources.StorageType.any;
+ default -> throw new IllegalStateException("Illegal storage-type value '" + storageType.asString() + "'");
+ };
}
private static String storageTypeToString(NodeResources.StorageType storageType) {
- switch (storageType) {
- case remote : return "remote";
- case local : return "local";
- case any : return "any";
- default: throw new IllegalStateException("Illegal storage-type value '" + storageType + "'");
- }
+ return switch (storageType) {
+ case remote -> "remote";
+ case local -> "local";
+ case any -> "any";
+ };
}
private static NodeResources.Architecture architectureFromSlime(Inspector architecture) {
if ( ! architecture.valid()) return NodeResources.Architecture.x86_64;
- switch (architecture.asString()) {
- case "x86_64" : return NodeResources.Architecture.x86_64;
- case "arm64" : return NodeResources.Architecture.arm64;
- case "any" : return NodeResources.Architecture.any;
- default: throw new IllegalStateException("Illegal architecture value '" + architecture.asString() + "'");
- }
+ return switch (architecture.asString()) {
+ case "x86_64" -> NodeResources.Architecture.x86_64;
+ case "arm64" -> NodeResources.Architecture.arm64;
+ case "any" -> NodeResources.Architecture.any;
+ default -> throw new IllegalStateException("Illegal architecture value '" + architecture.asString() + "'");
+ };
}
private static String architectureToString(NodeResources.Architecture architecture) {
- switch (architecture) {
- case x86_64: return "x86_64";
- case arm64: return "arm64";
- case any : return "any";
- default: throw new IllegalStateException("Illegal architecture value '" + architecture + "'");
- }
+ return switch (architecture) {
+ case x86_64 -> "x86_64";
+ case arm64 -> "arm64";
+ case any -> "any";
+ };
}
private static ClusterMembership membershipFromSlime(Inspector object) {
diff --git a/config-provisioning/src/main/resources/configdefinitions/config.provisioning.flavors.def b/config-provisioning/src/main/resources/configdefinitions/config.provisioning.flavors.def
index 765d6c2f812..b36069ebb57 100644
--- a/config-provisioning/src/main/resources/configdefinitions/config.provisioning.flavors.def
+++ b/config-provisioning/src/main/resources/configdefinitions/config.provisioning.flavors.def
@@ -38,3 +38,7 @@ flavor[].bandwidth double default=0.0
# The architecture for this flavor
flavor[].architecture string default="x86_64"
+
+# The GPU count and GPU memory (per GPU) of this flavor
+flavor[].gpuCount int default=0
+flavor[].gpuMemoryGb double default=0.0
diff --git a/config-provisioning/src/test/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializerTest.java b/config-provisioning/src/test/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializerTest.java
index fec8be59c50..172e0875767 100644
--- a/config-provisioning/src/test/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializerTest.java
+++ b/config-provisioning/src/test/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializerTest.java
@@ -27,7 +27,8 @@ public class AllocatedHostsSerializerTest {
private static final NodeResources smallSlowDiskSpeedNode = new NodeResources(0.5, 3.1, 4, 1, NodeResources.DiskSpeed.slow);
private static final NodeResources bigSlowDiskSpeedNode = new NodeResources(1.0, 6.2, 8, 2, NodeResources.DiskSpeed.slow);
- private static final NodeResources anyDiskSpeedNode = new NodeResources(0.5, 3.1, 4, 1, NodeResources.DiskSpeed.any);
+ private static final NodeResources anyDiskSpeedNode = new NodeResources(0.5, 3.1, 4, 1, NodeResources.DiskSpeed.any, NodeResources.StorageType.local,
+ NodeResources.Architecture.x86_64, new NodeResources.GpuResources(1, 16));
private static final NodeResources arm64Node = new NodeResources(0.5, 3.1, 4, 1, NodeResources.DiskSpeed.any, NodeResources.StorageType.any, NodeResources.Architecture.arm64);
@Test