summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Harald Musum <musum@yahooinc.com> 2023-11-21 09:52:52 +0100
committer GitHub <noreply@github.com> 2023-11-21 09:52:52 +0100
commit 430c0f8c9e1ea5eaeae2b795cd4b7350091679ae (patch)
tree 4e49701083ff1600df5775f481eb3057edbf88bf
parent d998b2774ce916ce5a92f4879f3f47a23f1346a9 (diff)
parent 9d28a47b003f5498bc59bfd10017dd55fc7ab6e0 (diff)
Merge pull request #29388 from vespa-engine/hmusum/register-with-onnx-model-options
Register model with onnx model options
-rw-r--r-- config-model-api/abi-spec.json 53
-rw-r--r-- config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java 5
-rw-r--r-- config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java (renamed from config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java) 8
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/OnnxModel.java 4
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java 3
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java 3
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java 3
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java 7
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java 4
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java 2
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java 11
11 files changed, 87 insertions, 16 deletions
diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json
index d9c68c89189..78b32d8af7b 100644
--- a/config-model-api/abi-spec.json
+++ b/config-model-api/abi-spec.json
@@ -1453,7 +1453,9 @@
"methods" : [
"public abstract long aggregatedModelCostInBytes()",
"public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile)",
- "public abstract void registerModel(java.net.URI)"
+ "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)",
+ "public abstract void registerModel(java.net.URI)",
+ "public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)"
],
"fields" : [ ]
},
@@ -1471,7 +1473,9 @@
"public com.yahoo.config.model.api.OnnxModelCost$Calculator newCalculator(com.yahoo.config.application.api.ApplicationPackage, com.yahoo.config.provision.ApplicationId)",
"public long aggregatedModelCostInBytes()",
"public void registerModel(com.yahoo.config.application.api.ApplicationFile)",
- "public void registerModel(java.net.URI)"
+ "public void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)",
+ "public void registerModel(java.net.URI)",
+ "public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)"
],
"fields" : [ ]
},
@@ -1489,6 +1493,51 @@
],
"fields" : [ ]
},
+ "com.yahoo.config.model.api.OnnxModelOptions$GpuDevice" : {
+ "superClass" : "java.lang.Record",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final",
+ "record"
+ ],
+ "methods" : [
+ "public void <init>(int, boolean)",
+ "public void <init>(int)",
+ "public final java.lang.String toString()",
+ "public final int hashCode()",
+ "public final boolean equals(java.lang.Object)",
+ "public int deviceNumber()",
+ "public boolean required()"
+ ],
+ "fields" : [ ]
+ },
+ "com.yahoo.config.model.api.OnnxModelOptions" : {
+ "superClass" : "java.lang.Record",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final",
+ "record"
+ ],
+ "methods" : [
+ "public void <init>(java.lang.String, int, int, com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)",
+ "public void <init>(java.util.Optional, java.util.Optional, java.util.Optional, java.util.Optional)",
+ "public static com.yahoo.config.model.api.OnnxModelOptions empty()",
+ "public com.yahoo.config.model.api.OnnxModelOptions withExecutionMode(java.lang.String)",
+ "public com.yahoo.config.model.api.OnnxModelOptions withInterOpThreads(java.lang.Integer)",
+ "public com.yahoo.config.model.api.OnnxModelOptions withIntraOpThreads(java.lang.Integer)",
+ "public com.yahoo.config.model.api.OnnxModelOptions withGpuDevice(com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)",
+ "public final java.lang.String toString()",
+ "public final int hashCode()",
+ "public final boolean equals(java.lang.Object)",
+ "public java.util.Optional executionMode()",
+ "public java.util.Optional interOpThreads()",
+ "public java.util.Optional intraOpThreads()",
+ "public java.util.Optional gpuDevice()"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.config.model.api.PortInfo" : {
"superClass" : "java.lang.Object",
"interfaces" : [ ],
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
index acb88070482..b98667457e4 100644
--- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
@@ -10,6 +10,7 @@ import java.net.URI;
/**
* @author bjorncs
*/
+// TODO: Rename
public interface OnnxModelCost {
Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId);
@@ -17,7 +18,9 @@ public interface OnnxModelCost {
interface Calculator {
long aggregatedModelCostInBytes();
void registerModel(ApplicationFile path);
+ void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions);
void registerModel(URI uri);
+ void registerModel(URI uri, OnnxModelOptions onnxModelOptions);
}
static OnnxModelCost disabled() { return new DisabledOnnxModelCost(); }
@@ -26,7 +29,9 @@ public interface OnnxModelCost {
@Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; }
@Override public long aggregatedModelCostInBytes() {return 0;}
@Override public void registerModel(ApplicationFile path) {}
+ @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {}
@Override public void registerModel(URI uri) {}
+ @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {}
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java
index 6347f0dc427..92817baae3f 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java
@@ -1,5 +1,5 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.model.container.component;
+package com.yahoo.config.model.api;
import java.util.Optional;
@@ -12,7 +12,11 @@ import java.util.Optional;
public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads,
Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) {
- public static OnnxModelOptions empty() {
+ public OnnxModelOptions(String executionMode, int interOpThreads, int intraOpThreads, GpuDevice gpuDevice) {
+ this(Optional.of(executionMode), Optional.of(interOpThreads), Optional.of(intraOpThreads), Optional.of(gpuDevice));
+ }
+
+ public static OnnxModelOptions empty() {
return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
}
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 867ffdb3960..9456baafd57 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -1,9 +1,9 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;
+import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
-import com.yahoo.vespa.model.container.component.OnnxModelOptions;
import com.yahoo.vespa.model.ml.OnnxModelInfo;
import java.util.Collections;
@@ -171,4 +171,6 @@ public class OnnxModel extends DistributableResource implements Cloneable {
return onnxModelOptions.gpuDevice();
}
+ public OnnxModelOptions onnxModelOptions() { return onnxModelOptions; }
+
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
index ea3caadc23a..67fb720b8c0 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
@@ -47,7 +48,7 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf
transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
- model.registerOnnxModelCost(cluster);
+ model.registerOnnxModelCost(cluster, onnxModelOptions);
}
@Override
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index cbae50b400c..d22e6afc3d1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
@@ -55,7 +56,7 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
- model.registerOnnxModelCost(cluster);
+ model.registerOnnxModelCost(cluster, onnxModelOptions);
}
private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index d1bd0dce000..d98c72ab3a4 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
@@ -48,7 +49,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
- model.registerOnnxModelCost(cluster);
+ model.registerOnnxModelCost(cluster, onnxModelOptions);
}
private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
index c5daf23d6f8..0d350242fd0 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
@@ -4,6 +4,7 @@ package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.builder.xml.XmlHelper;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.path.Path;
@@ -54,10 +55,10 @@ class Model {
return new Model(ds, model.getTagName(), modelId, url, path);
}
- void registerOnnxModelCost(ApplicationContainerCluster c) {
+ void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) {
var resolvedUrl = resolvedUrl().orElse(null);
- if (file != null) c.onnxModelCost().registerModel(file);
- else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl);
+ if (file != null) c.onnxModelCost().registerModel(file, onnxModelOptions);
+ else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl, onnxModelOptions);
}
String name() { return paramName; }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java
index d86d117f1d2..31468c05b99 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java
@@ -52,11 +52,11 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains>
private final List<SearchCluster> searchClusters = new LinkedList<>();
private final Collection<String> schemasWithGlobalPhase;
private final boolean globalPhase;
+ private final ApplicationPackage app;
private QueryProfiles queryProfiles;
private SemanticRules semanticRules;
private PageTemplates pageTemplates;
- private ApplicationPackage app;
public ContainerSearch(DeployState deployState, ApplicationContainerCluster cluster, SearchChains chains) {
super(chains);
@@ -102,7 +102,7 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains>
if ( ! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) {
var onnxModels = documentDb.getDerivedConfiguration().getRankProfileList().getOnnxModels();
onnxModels.asMap().forEach(
- (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath())));
+ (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions()));
owningCluster.addComponent(factory);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index 18020f5df5d..5ffd34c6557 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -800,7 +800,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
!container.getHostResource().realResources().gpuResources().isZero());
onnxModel.setGpuDevice(gpuDevice, hasGpu);
}
- cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()));
+ cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions());
}
cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models));
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
index 8531aff3b1a..9cadf5cffd8 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
@@ -1,14 +1,13 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
package com.yahoo.vespa.model.application.validation;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.NullConfigModelRegistry;
import com.yahoo.config.model.api.ApplicationClusterEndpoint;
import com.yahoo.config.model.api.ContainerEndpoint;
import com.yahoo.config.model.api.OnnxModelCost;
+import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.config.model.provision.InMemoryProvisioner;
@@ -123,12 +122,20 @@ class JvmHeapSizeValidatorTest {
@Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; }
@Override public long aggregatedModelCostInBytes() { return totalCost.get(); }
@Override public void registerModel(ApplicationFile path) {}
+ @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {}
@Override
public void registerModel(URI uri) {
assertEquals("https://my/url/model.onnx", uri.toString());
totalCost.addAndGet(modelCost);
}
+
+ @Override
+ public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {
+ assertEquals("https://my/url/model.onnx", uri.toString());
+ totalCost.addAndGet(modelCost);
+ }
+
}
}