From de749a8872fd1d96b7ae7b194b80764abe2768ed Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Wed, 16 Nov 2022 15:36:09 +0100 Subject: Support GPU in node specification --- .../java/com/yahoo/config/provision/Capacity.java | 3 + .../java/com/yahoo/config/provision/Flavor.java | 3 +- .../com/yahoo/config/provision/NodeResources.java | 61 ++++++++++++++- .../serialization/AllocatedHostsSerializer.java | 88 ++++++++++++---------- .../config.provisioning.flavors.def | 4 + .../AllocatedHostsSerializerTest.java | 3 +- 6 files changed, 116 insertions(+), 46 deletions(-) (limited to 'config-provisioning') diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java b/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java index 70e88418fb7..dd8d181e1df 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java @@ -3,6 +3,7 @@ package com.yahoo.config.provision; import java.util.Objects; import java.util.Optional; +import java.util.stream.Stream; /** * A capacity request. @@ -25,6 +26,8 @@ public final class Capacity { if (max.smallerThan(min)) throw new IllegalArgumentException("The max capacity must be larger than the min capacity, but got min " + min + " and max " + max); + if (!min.equals(max) && Stream.of(min, max).anyMatch(cr -> !cr.nodeResources().gpuResources().isDefault())) + throw new IllegalArgumentException("Capacity range does not allow GPU, got min " + min + " and max " + max); this.min = min; this.max = max; this.required = required; diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/Flavor.java b/config-provisioning/src/main/java/com/yahoo/config/provision/Flavor.java index 1ae4974a4c2..4539aec58eb 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/Flavor.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/Flavor.java @@ -38,7 +38,8 @@ public class Flavor { flavorConfig.bandwidth() / 1000, flavorConfig.fastDisk() ? NodeResources.DiskSpeed.fast : NodeResources.DiskSpeed.slow, flavorConfig.remoteStorage() ? NodeResources.StorageType.remote : NodeResources.StorageType.local, - NodeResources.Architecture.valueOf(flavorConfig.architecture())), + NodeResources.Architecture.valueOf(flavorConfig.architecture()), + new NodeResources.GpuResources(flavorConfig.gpuCount(), flavorConfig.gpuMemoryGb())), Optional.empty(), Type.valueOf(flavorConfig.environment()), true, diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java index 25771f1906b..e6bd7c70a82 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java @@ -120,10 +120,48 @@ public class NodeResources { } + public record GpuResources(int count, double memoryGb) { + + private static final GpuResources none = new GpuResources(0, 0); + + public GpuResources { + if (count < 0) throw new IllegalArgumentException("GPU count cannot be negative, got " + count); + if (memoryGb < 0) throw new IllegalArgumentException("GPU memory cannot be negative, got " + memoryGb); + validate(memoryGb, "memory"); + } + + private double totalMemory() { + return count * memoryGb; + } + + public boolean lessThan(GpuResources other) { + return totalMemory() < other.totalMemory(); + } + + public boolean isDefault() { return this.equals(getDefault()); } + + public static GpuResources getDefault() { return none; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + GpuResources that = (GpuResources) o; + return count == that.count && equal(this.memoryGb, that.memoryGb); + } + + @Override + public int hashCode() { + return Objects.hash(count, memoryGb); + } + + } + private final double vcpu; private final double memoryGb; private final double diskGb; private final double bandwidthGbps; + private final GpuResources gpuResources; private final DiskSpeed diskSpeed; private final StorageType storageType; private final Architecture architecture; @@ -133,18 +171,23 @@ public class NodeResources { } public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed) { - this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, StorageType.getDefault(), Architecture.getDefault()); + this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, StorageType.getDefault(), Architecture.getDefault(), GpuResources.getDefault()); } public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType) { - this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, storageType, Architecture.getDefault()); + this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, storageType, Architecture.getDefault(), GpuResources.getDefault()); } public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType, Architecture architecture) { + this(vcpu, memoryGb, diskGb, bandwidthGbps, diskSpeed, storageType, architecture, GpuResources.getDefault()); + } + + public NodeResources(double vcpu, double memoryGb, double diskGb, double bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType, Architecture architecture, GpuResources gpuResources) { this.vcpu = validate(vcpu, "vcpu"); this.memoryGb = validate(memoryGb, "memory"); this.diskGb = validate(diskGb, "disk"); this.bandwidthGbps = validate(bandwidthGbps, "bandwidth"); + this.gpuResources = gpuResources; this.diskSpeed = diskSpeed; this.storageType = storageType; this.architecture = architecture; @@ -157,6 +200,7 @@ public class NodeResources { public DiskSpeed diskSpeed() { return diskSpeed; } public StorageType storageType() { return storageType; } public Architecture architecture() { return architecture; } + public GpuResources gpuResources() { return gpuResources; } /** Returns the standard cost of these resources, in dollars per hour */ public double cost() { @@ -265,6 +309,7 @@ public class NodeResources { if ( ! equal(this.memoryGb, other.memoryGb)) return false; if ( ! equal(this.diskGb, other.diskGb)) return false; if ( ! equal(this.bandwidthGbps, other.bandwidthGbps)) return false; + if ( ! this.gpuResources.equals(other.gpuResources)) return false; if (this.diskSpeed != other.diskSpeed) return false; if (this.storageType != other.storageType) return false; if (this.architecture != other.architecture) return false; @@ -306,6 +351,12 @@ public class NodeResources { sb.append(", storage type: ").append(storageType); } sb.append(", architecture: ").append(architecture); + if ( !gpuResources.isDefault()) { + sb.append(", gpu count: ").append(gpuResources.count()); + sb.append(", gpu memory: "); + appendDouble(sb, memoryGb); + sb.append(" Gb"); + } sb.append(']'); return sb.toString(); } @@ -318,6 +369,7 @@ public class NodeResources { if (this.memoryGb < other.memoryGb) return false; if (this.diskGb < other.diskGb) return false; if (this.bandwidthGbps < other.bandwidthGbps) return false; + if (this.gpuResources.lessThan(other.gpuResources)) return false; // Why doesn't a fast disk satisfy a slow disk? Because if slow disk is explicitly specified // (i.e not "any"), you should not randomly, sometimes get a faster disk as that means you may @@ -339,6 +391,7 @@ public class NodeResources { if ( ! equal(this.memoryGb, other.memoryGb)) return false; if ( ! equal(this.diskGb, other.diskGb)) return false; if ( ! equal(this.bandwidthGbps, other.bandwidthGbps)) return false; + if ( ! this.gpuResources.equals(other.gpuResources)) return false; if ( ! this.diskSpeed.compatibleWith(other.diskSpeed)) return false; if ( ! this.storageType.compatibleWith(other.storageType)) return false; if ( ! this.architecture.compatibleWith(other.architecture)) return false; @@ -371,7 +424,7 @@ public class NodeResources { return this.isUnspecified() ? Optional.empty() : Optional.of(this); } - private boolean equal(double a, double b) { + private static boolean equal(double a, double b) { return Math.abs(a - b) < 0.00000001; } @@ -396,7 +449,7 @@ public class NodeResources { return new NodeResources(cpu, mem, dsk, 0.3, DiskSpeed.getDefault(), StorageType.getDefault(), Architecture.x86_64); } - private double validate(double value, String valueName) { + private static double validate(double value, String valueName) { if (Double.isNaN(value)) throw new IllegalArgumentException(valueName + " cannot be NaN"); if (Double.isInfinite(value)) throw new IllegalArgumentException(valueName + " cannot be infinite"); return value; diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializer.java b/config-provisioning/src/main/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializer.java index 5db7303f4bf..f539bc19c49 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializer.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializer.java @@ -13,9 +13,7 @@ import com.yahoo.slime.Slime; import com.yahoo.slime.SlimeUtils; import java.io.IOException; -import java.util.ArrayList; import java.util.LinkedHashSet; -import java.util.List; import java.util.Optional; import java.util.Set; @@ -51,6 +49,8 @@ public class AllocatedHostsSerializer { private static final String diskSpeedKey = "diskSpeed"; private static final String storageTypeKey = "storageType"; private static final String architectureKey = "architecture"; + private static final String gpuCountKey = "gpuCount"; + private static final String gpuMemoryKey = "gpuMemory"; /** Wanted version */ private static final String hostSpecVespaVersionKey = "vespaVersion"; @@ -97,6 +97,10 @@ public class AllocatedHostsSerializer { resourcesObject.setString(diskSpeedKey, diskSpeedToString(resources.diskSpeed())); resourcesObject.setString(storageTypeKey, storageTypeToString(resources.storageType())); resourcesObject.setString(architectureKey, architectureToString(resources.architecture())); + if (!resources.gpuResources().isDefault()) { + resourcesObject.setLong(gpuCountKey, resources.gpuResources().count()); + resourcesObject.setDouble(gpuMemoryKey, resources.gpuResources().memoryGb()); + } } public static AllocatedHosts fromJson(byte[] json) { @@ -113,7 +117,6 @@ public class AllocatedHostsSerializer { } private static HostSpec hostFromSlime(Inspector object) { - if (object.field(hostSpecMembershipKey).valid()) { // Hosted return new HostSpec(object.field(hostSpecHostNameKey).asString(), nodeResourcesFromSlime(object.field(realResourcesKey)), @@ -137,7 +140,15 @@ public class AllocatedHostsSerializer { resources.field(bandwidthKey).asDouble(), diskSpeedFromSlime(resources.field(diskSpeedKey)), storageTypeFromSlime(resources.field(storageTypeKey)), - architectureFromSlime(resources.field(architectureKey))); + architectureFromSlime(resources.field(architectureKey)), + gpuResourcesFromSlime(resources)); + } + + private static NodeResources.GpuResources gpuResourcesFromSlime(Inspector resources) { + Inspector gpuCountField = resources.field(gpuCountKey); + Inspector gpuMemoryField = resources.field(gpuMemoryKey); + if (!gpuCountField.valid() || !gpuMemoryField.valid()) return NodeResources.GpuResources.getDefault(); + return new NodeResources.GpuResources((int) gpuCountField.asLong(), gpuMemoryField.asDouble()); } private static NodeResources optionalNodeResourcesFromSlime(Inspector resources) { @@ -146,58 +157,55 @@ public class AllocatedHostsSerializer { } private static NodeResources.DiskSpeed diskSpeedFromSlime(Inspector diskSpeed) { - switch (diskSpeed.asString()) { - case "fast" : return NodeResources.DiskSpeed.fast; - case "slow" : return NodeResources.DiskSpeed.slow; - case "any" : return NodeResources.DiskSpeed.any; - default: throw new IllegalStateException("Illegal disk-speed value '" + diskSpeed.asString() + "'"); - } + return switch (diskSpeed.asString()) { + case "fast" -> NodeResources.DiskSpeed.fast; + case "slow" -> NodeResources.DiskSpeed.slow; + case "any" -> NodeResources.DiskSpeed.any; + default -> throw new IllegalStateException("Illegal disk-speed value '" + diskSpeed.asString() + "'"); + }; } private static String diskSpeedToString(NodeResources.DiskSpeed diskSpeed) { - switch (diskSpeed) { - case fast : return "fast"; - case slow : return "slow"; - case any : return "any"; - default: throw new IllegalStateException("Illegal disk-speed value '" + diskSpeed + "'"); - } + return switch (diskSpeed) { + case fast -> "fast"; + case slow -> "slow"; + case any -> "any"; + }; } private static NodeResources.StorageType storageTypeFromSlime(Inspector storageType) { - switch (storageType.asString()) { - case "remote" : return NodeResources.StorageType.remote; - case "local" : return NodeResources.StorageType.local; - case "any" : return NodeResources.StorageType.any; - default: throw new IllegalStateException("Illegal storage-type value '" + storageType.asString() + "'"); - } + return switch (storageType.asString()) { + case "remote" -> NodeResources.StorageType.remote; + case "local" -> NodeResources.StorageType.local; + case "any" -> NodeResources.StorageType.any; + default -> throw new IllegalStateException("Illegal storage-type value '" + storageType.asString() + "'"); + }; } private static String storageTypeToString(NodeResources.StorageType storageType) { - switch (storageType) { - case remote : return "remote"; - case local : return "local"; - case any : return "any"; - default: throw new IllegalStateException("Illegal storage-type value '" + storageType + "'"); - } + return switch (storageType) { + case remote -> "remote"; + case local -> "local"; + case any -> "any"; + }; } private static NodeResources.Architecture architectureFromSlime(Inspector architecture) { if ( ! architecture.valid()) return NodeResources.Architecture.x86_64; - switch (architecture.asString()) { - case "x86_64" : return NodeResources.Architecture.x86_64; - case "arm64" : return NodeResources.Architecture.arm64; - case "any" : return NodeResources.Architecture.any; - default: throw new IllegalStateException("Illegal architecture value '" + architecture.asString() + "'"); - } + return switch (architecture.asString()) { + case "x86_64" -> NodeResources.Architecture.x86_64; + case "arm64" -> NodeResources.Architecture.arm64; + case "any" -> NodeResources.Architecture.any; + default -> throw new IllegalStateException("Illegal architecture value '" + architecture.asString() + "'"); + }; } private static String architectureToString(NodeResources.Architecture architecture) { - switch (architecture) { - case x86_64: return "x86_64"; - case arm64: return "arm64"; - case any : return "any"; - default: throw new IllegalStateException("Illegal architecture value '" + architecture + "'"); - } + return switch (architecture) { + case x86_64 -> "x86_64"; + case arm64 -> "arm64"; + case any -> "any"; + }; } private static ClusterMembership membershipFromSlime(Inspector object) { diff --git a/config-provisioning/src/main/resources/configdefinitions/config.provisioning.flavors.def b/config-provisioning/src/main/resources/configdefinitions/config.provisioning.flavors.def index 765d6c2f812..b36069ebb57 100644 --- a/config-provisioning/src/main/resources/configdefinitions/config.provisioning.flavors.def +++ b/config-provisioning/src/main/resources/configdefinitions/config.provisioning.flavors.def @@ -38,3 +38,7 @@ flavor[].bandwidth double default=0.0 # The architecture for this flavor flavor[].architecture string default="x86_64" + +# The GPU count and GPU memory (per GPU) of this flavor +flavor[].gpuCount int default=0 +flavor[].gpuMemoryGb double default=0.0 diff --git a/config-provisioning/src/test/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializerTest.java b/config-provisioning/src/test/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializerTest.java index fec8be59c50..172e0875767 100644 --- a/config-provisioning/src/test/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializerTest.java +++ b/config-provisioning/src/test/java/com/yahoo/config/provision/serialization/AllocatedHostsSerializerTest.java @@ -27,7 +27,8 @@ public class AllocatedHostsSerializerTest { private static final NodeResources smallSlowDiskSpeedNode = new NodeResources(0.5, 3.1, 4, 1, NodeResources.DiskSpeed.slow); private static final NodeResources bigSlowDiskSpeedNode = new NodeResources(1.0, 6.2, 8, 2, NodeResources.DiskSpeed.slow); - private static final NodeResources anyDiskSpeedNode = new NodeResources(0.5, 3.1, 4, 1, NodeResources.DiskSpeed.any); + private static final NodeResources anyDiskSpeedNode = new NodeResources(0.5, 3.1, 4, 1, NodeResources.DiskSpeed.any, NodeResources.StorageType.local, + NodeResources.Architecture.x86_64, new NodeResources.GpuResources(1, 16)); private static final NodeResources arm64Node = new NodeResources(0.5, 3.1, 4, 1, NodeResources.DiskSpeed.any, NodeResources.StorageType.any, NodeResources.Architecture.arm64); @Test -- cgit v1.2.3