aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-09 14:59:49 +0200
committerGitHub <noreply@github.com>2023-05-09 14:59:49 +0200
commit05fd745f30f48cde26b8db13c08e47a1b301439d (patch)
treed25cac76d37fcfa281240fa396f9974902b3834f /config-model
parent09898990dd5ba74d09da74a56f137e2fff505be5 (diff)
parent372bd2c677bb9707c55a9153f860fb2017ce6ffc (diff)
Merge pull request #27045 from vespa-engine/bjorncs/embedder-onnx-gpu
Reapply "Bjorncs/embedder onnx gpu"
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainer.java27
-rw-r--r--config-model/src/test/derived/globalphase_onnx_inside/onnx-models.cfg1
-rw-r--r--config-model/src/test/derived/globalphase_token_functions/onnx-models.cfg1
-rw-r--r--config-model/src/test/derived/vector_constant/onnx-models.cfg1
4 files changed, 24 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainer.java
index f901bf3c826..9e21fd2d23a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainer.java
@@ -9,6 +9,7 @@ import com.yahoo.config.model.producer.TreeConfigProducer;
import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.search.config.QrStartConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.model.container.component.SimpleComponent;
import java.time.Duration;
import java.util.Optional;
@@ -20,6 +21,7 @@ import java.util.Optional;
*/
public final class ApplicationContainer extends Container implements
QrStartConfig.Producer,
+ OnnxModelsConfig.Producer,
ZookeeperServerConfig.Producer {
private final boolean isHostedVespa;
@@ -42,12 +44,15 @@ public final class ApplicationContainer extends Container implements
@Override
public void getConfig(QrStartConfig.Builder builder) {
- if (getHostResource() != null) {
- NodeResources nodeResources = getHostResource().realResources();
- if ( ! nodeResources.isUnspecified()) {
- builder.jvm.availableProcessors(Math.max(2, (int)Math.ceil(nodeResources.vcpu())));
- }
- }
+ realResources().ifPresent(r -> builder.jvm.availableProcessors(Math.max(2, (int) Math.ceil(r.vcpu()))));
+ }
+
+ @Override
+ public void getConfig(OnnxModelsConfig.Builder builder) {
+ realResources().ifPresent(r -> {
+ int count = r.gpuResources().count();
+ if (count >= 0) builder.gpu.count(count);
+ });
}
@Override
@@ -84,4 +89,14 @@ public final class ApplicationContainer extends Container implements
@Override public Optional<String> getPreShutdownCommand() { return Optional.of(prepareStopCommand(Duration.ofMinutes(6))); }
+ private Optional<NodeResources> realResources() {
+ if (getHostResource() != null) {
+ NodeResources nodeResources = getHostResource().realResources();
+ if ( ! nodeResources.isUnspecified()) {
+ return Optional.of(nodeResources);
+ }
+ }
+ return Optional.empty();
+ }
+
}
diff --git a/config-model/src/test/derived/globalphase_onnx_inside/onnx-models.cfg b/config-model/src/test/derived/globalphase_onnx_inside/onnx-models.cfg
index d63e85e2f19..99f65336794 100644
--- a/config-model/src/test/derived/globalphase_onnx_inside/onnx-models.cfg
+++ b/config-model/src/test/derived/globalphase_onnx_inside/onnx-models.cfg
@@ -1,3 +1,4 @@
+gpu.count -1
model[].name "direct"
model[].fileref "files/ax_plus_b.onnx"
model[].input[].name "vector_B"
diff --git a/config-model/src/test/derived/globalphase_token_functions/onnx-models.cfg b/config-model/src/test/derived/globalphase_token_functions/onnx-models.cfg
index 6283159c324..cea4c065014 100644
--- a/config-model/src/test/derived/globalphase_token_functions/onnx-models.cfg
+++ b/config-model/src/test/derived/globalphase_token_functions/onnx-models.cfg
@@ -1,3 +1,4 @@
+gpu.count -1
model[].name "my_ranking_model"
model[].fileref "files/ranking_model.onnx"
model[].input[].name "input_ids"
diff --git a/config-model/src/test/derived/vector_constant/onnx-models.cfg b/config-model/src/test/derived/vector_constant/onnx-models.cfg
index 4c52b72b519..1dcaf0e1bd6 100644
--- a/config-model/src/test/derived/vector_constant/onnx-models.cfg
+++ b/config-model/src/test/derived/vector_constant/onnx-models.cfg
@@ -1,3 +1,4 @@
+gpu.count -1
model[].name "inside"
model[].fileref "ax_plus_b.onnx"
model[].input[].name "vector_B"