author     Harald Musum <musum@yahooinc.com>    2023-11-16 09:27:02 +0100
committer  Harald Musum <musum@yahooinc.com>    2023-11-16 09:27:02 +0100
commit     ae5b333a133add3fc6d7ab6c056379c5b4f1564d (patch)
tree       094bc2c57b5fa64a8c7850b793afcf3a8dd05750 /config-model
parent     8c93087d6df4e995e61f17a92ab462e135608225 (diff)
Add OnnxModelOptions and use it in embedders
Diffstat (limited to 'config-model')
-rw-r--r--  config-model/src/main/java/com/yahoo/schema/OnnxModel.java                                      32
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java          77
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java       89
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java   76
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java      84
5 files changed, 201 insertions, 157 deletions
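
This commit collapses the runtime ONNX options that were previously held as separate nullable fields in OnnxModel and in the three embedder components into a single immutable OnnxModelOptions record (added as a new file at the end of this diff). Updates are expressed as copy-on-write "with" methods that return a new record instance with one field replaced. A minimal sketch of that pattern, with the field list trimmed to two entries for illustration (OptionsSketch is a placeholder name, not part of the commit):

    // Sketch only: the real OnnxModelOptions (new file below) carries many more fields.
    public record OptionsSketch(String executionMode, Integer interOpThreads) {

        // Return a copy with only executionMode replaced; the original instance is untouched.
        public OptionsSketch withExecutionMode(String executionMode) {
            return new OptionsSketch(executionMode, interOpThreads);
        }

        // Return a copy with only interOpThreads replaced.
        public OptionsSketch withInterOpThreads(Integer interOpThreads) {
            return new OptionsSketch(executionMode, interOpThreads);
        }
    }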
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index f3f09150c1d..b22e737fd6a 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -3,6 +3,7 @@ package com.yahoo.schema;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.model.container.component.OnnxModelOptions;
import com.yahoo.vespa.model.ml.OnnxModelInfo;
import java.util.Collections;
@@ -27,10 +28,7 @@ public class OnnxModel extends DistributableResource implements Cloneable {
private final Set<String> initializers = new HashSet<>();
// Runtime options
- private String statelessExecutionMode = null;
- private Integer statelessInterOpThreads = null;
- private Integer statelessIntraOpThreads = null;
- private GpuDevice gpuDevice = null;
+ private OnnxModelOptions onnxModelOptions = OnnxModelOptions.empty();
public OnnxModel(String name) {
super(name);
@@ -133,50 +131,44 @@ public class OnnxModel extends DistributableResource implements Cloneable {
public void setStatelessExecutionMode(String executionMode) {
if ("parallel".equalsIgnoreCase(executionMode)) {
- this.statelessExecutionMode = "parallel";
+ onnxModelOptions = onnxModelOptions.withExecutionMode("parallel");
} else if ("sequential".equalsIgnoreCase(executionMode)) {
- this.statelessExecutionMode = "sequential";
+ onnxModelOptions = onnxModelOptions.withExecutionMode("sequential");
}
}
public Optional<String> getStatelessExecutionMode() {
- return Optional.ofNullable(statelessExecutionMode);
+ return Optional.ofNullable(onnxModelOptions.executionMode());
}
public void setStatelessInterOpThreads(int interOpThreads) {
if (interOpThreads >= 0) {
- this.statelessInterOpThreads = interOpThreads;
+ onnxModelOptions = onnxModelOptions.withInteropThreads(interOpThreads);
}
}
public Optional<Integer> getStatelessInterOpThreads() {
- return Optional.ofNullable(statelessInterOpThreads);
+ return Optional.ofNullable(onnxModelOptions.interOpThreads());
}
public void setStatelessIntraOpThreads(int intraOpThreads) {
if (intraOpThreads >= 0) {
- this.statelessIntraOpThreads = intraOpThreads;
+ onnxModelOptions = onnxModelOptions.withIntraopThreads(intraOpThreads);
}
}
public Optional<Integer> getStatelessIntraOpThreads() {
- return Optional.ofNullable(statelessIntraOpThreads);
+ return Optional.ofNullable(onnxModelOptions.intraOpThreads());
}
public void setGpuDevice(int deviceNumber, boolean required) {
if (deviceNumber >= 0) {
- this.gpuDevice = new GpuDevice(deviceNumber, required);
+ onnxModelOptions = onnxModelOptions.withGpuDevice(new OnnxModelOptions.GpuDevice(deviceNumber, required));
}
}
- public Optional<GpuDevice> getGpuDevice() {
- return Optional.ofNullable(gpuDevice);
- }
-
- public record GpuDevice(int deviceNumber, boolean required) {
- public GpuDevice {
- if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
- }
+ public Optional<OnnxModelOptions.GpuDevice> getGpuDevice() {
+ return Optional.ofNullable(onnxModelOptions.gpuDevice());
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
index a644382625b..67edb40b4d3 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -1,13 +1,13 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
package com.yahoo.vespa.model.container.component;
-import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode;
+import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
@@ -16,56 +16,45 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
*/
public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer {
- private final ModelReference modelRef;
- private final ModelReference vocabRef;
- private final Integer maxTokens;
- private final String transformerInputIds;
- private final String transformerAttentionMask;
- private final String transformerTokenTypeIds;
- private final String transformerOutput;
- private final Integer tranformerStartSequenceToken;
- private final Integer transformerEndSequenceToken;
- private final String poolingStrategy;
- private final String onnxExecutionMode;
- private final Integer onnxInteropThreads;
- private final Integer onnxIntraopThreads;
- private final Integer onnxGpuDevice;
-
+ private final OnnxModelOptions onnxModelOptions;
public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
- modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference();
- maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
- transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
- transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
- transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null);
- transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
- tranformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
- transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
- poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
- onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null);
- onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
- onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
- onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+ this.onnxModelOptions = new OnnxModelOptions(
+ model.modelReference(),
+ Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(),
+ getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "transformer-input-ids").orElse(null),
+ getChildValue(xml, "transformer-attention-mask").orElse(null),
+ getChildValue(xml, "transformer-token-type-ids").orElse(null),
+ getChildValue(xml, "transformer-output").orElse(null),
+ getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null),
+ getChildValue(xml, "onnx-execution-mode").orElse(null),
+ getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null),
+ getChildValue(xml, "pooling-strategy").orElse(null),
+ getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null));
model.registerOnnxModelCost(cluster);
}
@Override
public void getConfig(BertBaseEmbedderConfig.Builder b) {
- b.transformerModel(modelRef).tokenizerVocab(vocabRef);
- if (maxTokens != null) b.transformerMaxTokens(maxTokens);
- if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
- if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
- if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
- if (transformerOutput != null) b.transformerOutput(transformerOutput);
- if (tranformerStartSequenceToken != null) b.transformerStartSequenceToken(tranformerStartSequenceToken);
- if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
- if (poolingStrategy != null) b.poolingStrategy(BertBaseEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy));
- if (onnxExecutionMode != null) b.onnxExecutionMode(BertBaseEmbedderConfig.OnnxExecutionMode.Enum.valueOf(onnxExecutionMode));
- if (onnxInteropThreads != null) b.onnxInterOpThreads(onnxInteropThreads);
- if (onnxIntraopThreads != null) b.onnxIntraOpThreads(onnxIntraopThreads);
- if (onnxGpuDevice != null) b.onnxGpuDevice(onnxGpuDevice);
+ b.transformerModel(onnxModelOptions.modelRef()).tokenizerVocab(onnxModelOptions.vocabRef());
+ if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens());
+ if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds());
+ if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask());
+ if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds());
+ if (onnxModelOptions.transformerOutput() != null) b.transformerOutput(onnxModelOptions.transformerOutput());
+ if (onnxModelOptions.transformerStartSequenceToken() != null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken());
+ if (onnxModelOptions.transformerEndSequenceToken() != null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken());
+ if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy()));
+ if (onnxModelOptions.executionMode() != null) b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(onnxModelOptions.executionMode()));
+ if (onnxModelOptions.interOpThreads() != null) b.onnxInterOpThreads(onnxModelOptions.interOpThreads());
+ if (onnxModelOptions.intraOpThreads() != null) b.onnxIntraOpThreads(onnxModelOptions.intraOpThreads());
+ if (onnxModelOptions.gpuDevice() != null) b.onnxGpuDevice(onnxModelOptions.gpuDevice().deviceNumber());
}
+
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index ed56579988d..5192c806797 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -1,5 +1,4 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
@@ -8,6 +7,7 @@ import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
@@ -16,46 +16,33 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
* @author bergum
*/
public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer {
- private final ModelReference modelRef;
- private final ModelReference vocabRef;
-
- private final Integer maxQueryTokens;
-
- private final Integer maxDocumentTokens;
- private final Integer transformerStartSequenceToken;
- private final Integer transformerEndSequenceToken;
- private final Integer transformerMaskToken;
- private final Integer maxTokens;
- private final String transformerInputIds;
- private final String transformerAttentionMask;
-
- private final String transformerOutput;
- private final String onnxExecutionMode;
- private final Integer onnxInteropThreads;
- private final Integer onnxIntraopThreads;
- private final Integer onnxGpuDevice;
+ private final OnnxModelOptions onnxModelOptions;
public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
- modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-model")
- .map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state));
- maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
- maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null);
- maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null);
- transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
- transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
- transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null);
- transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
- transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
- transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
- onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null);
- onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
- onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
- onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+ this.onnxModelOptions = new OnnxModelOptions(
+ model.modelReference(),
+ Model.fromXml(state, xml, "tokenizer-model")
+ .map(Model::modelReference)
+ .orElseGet(() -> resolveDefaultVocab(model, state)),
+ getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "transformer-input-ids").orElse(null),
+ getChildValue(xml, "transformer-attention-mask").orElse(null),
+ getChildValue(xml, "transformer-token-type-ids").orElse(null),
+ getChildValue(xml, "transformer-output").orElse(null),
+ getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null),
+ getChildValue(xml, "onnx-execution-mode").orElse(null),
+ getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null),
+ getChildValue(xml, "pooling-strategy").orElse(null),
+ getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null));
model.registerOnnxModelCost(cluster);
}
@@ -69,20 +56,20 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
@Override
public void getConfig(ColBertEmbedderConfig.Builder b) {
- b.transformerModel(modelRef).tokenizerPath(vocabRef);
- if (maxTokens != null) b.transformerMaxTokens(maxTokens);
- if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
- if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
- if (transformerOutput != null) b.transformerOutput(transformerOutput);
- if (maxQueryTokens != null) b.maxQueryTokens(maxQueryTokens);
- if (maxDocumentTokens != null) b.maxDocumentTokens(maxDocumentTokens);
- if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken);
- if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
- if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken);
- if (onnxExecutionMode != null) b.transformerExecutionMode(
- ColBertEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode));
- if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
- if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
- if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
+ b.transformerModel(onnxModelOptions.modelRef()).tokenizerPath(onnxModelOptions.vocabRef());
+ if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens());
+ if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds());
+ if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask());
+ if (onnxModelOptions.transformerOutput() != null) b.transformerOutput(onnxModelOptions.transformerOutput());
+ if (onnxModelOptions.maxQueryTokens() != null) b.maxQueryTokens(onnxModelOptions.maxQueryTokens());
+ if (onnxModelOptions.maxDocumentTokens() != null) b.maxDocumentTokens(onnxModelOptions.maxDocumentTokens());
+ if (onnxModelOptions.transformerStartSequenceToken() != null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken());
+ if (onnxModelOptions.transformerEndSequenceToken() != null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken());
+ if (onnxModelOptions.transformerMaskToken() != null) b.transformerMaskToken(onnxModelOptions.transformerMaskToken());
+ if (onnxModelOptions.executionMode() != null) b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(onnxModelOptions.executionMode()));
+ if (onnxModelOptions.interOpThreads() != null) b.transformerInterOpThreads(onnxModelOptions.interOpThreads());
+ if (onnxModelOptions.intraOpThreads() != null) b.transformerIntraOpThreads(onnxModelOptions.intraOpThreads());
+ if (onnxModelOptions.gpuDevice() != null) b.transformerGpuDevice(onnxModelOptions.gpuDevice().deviceNumber());
}
+
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index 31b86142445..5bf30174539 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -1,5 +1,4 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
@@ -8,6 +7,8 @@ import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy;
+import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
@@ -16,38 +17,29 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
* @author bjorncs
*/
public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer {
- private final ModelReference modelRef;
- private final ModelReference vocabRef;
- private final Integer maxTokens;
- private final String transformerInputIds;
- private final String transformerAttentionMask;
- private final String transformerTokenTypeIds;
- private final String transformerOutput;
- private final Boolean normalize;
- private final String onnxExecutionMode;
- private final Integer onnxInteropThreads;
- private final Integer onnxIntraopThreads;
- private final Integer onnxGpuDevice;
- private final String poolingStrategy;
+
+ private final OnnxModelOptions onnxModelOptions;
public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
- modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-model")
- .map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state));
- maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
- transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
- transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
- transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null);
- transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
- normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
- onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null);
- onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
- onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
- onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
- poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
+ var modelRef = model.modelReference();
+ this.onnxModelOptions = new OnnxModelOptions(
+ modelRef,
+ Model.fromXml(state, xml, "tokenizer-model")
+ .map(Model::modelReference)
+ .orElseGet(() -> resolveDefaultVocab(model, state)),
+ getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "transformer-input-ids").orElse(null),
+ getChildValue(xml, "transformer-attention-mask").orElse(null),
+ getChildValue(xml, "transformer-token-type-ids").orElse(null),
+ getChildValue(xml, "transformer-output").orElse(null),
+ getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null),
+ getChildValue(xml, "onnx-execution-mode").orElse(null),
+ getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null),
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null),
+ getChildValue(xml, "pooling-strategy").orElse(null));
model.registerOnnxModelCost(cluster);
}
@@ -61,18 +53,18 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
@Override
public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
- b.transformerModel(modelRef).tokenizerPath(vocabRef);
- if (maxTokens != null) b.transformerMaxTokens(maxTokens);
- if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
- if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
- if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
- if (transformerOutput != null) b.transformerOutput(transformerOutput);
- if (normalize != null) b.normalize(normalize);
- if (onnxExecutionMode != null) b.transformerExecutionMode(
- HuggingFaceEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode));
- if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
- if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
- if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
- if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy));
+ b.transformerModel(onnxModelOptions.modelRef()).tokenizerPath(onnxModelOptions.vocabRef());
+ if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens());
+ if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds());
+ if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask());
+ if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds());
+ if (onnxModelOptions.transformerOutput() != null) b.transformerOutput(onnxModelOptions.transformerOutput());
+ if (onnxModelOptions.normalize() != null) b.normalize(onnxModelOptions.normalize());
+ if (onnxModelOptions.executionMode() != null) b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(onnxModelOptions.executionMode()));
+ if (onnxModelOptions.interOpThreads() != null) b.transformerInterOpThreads(onnxModelOptions.interOpThreads());
+ if (onnxModelOptions.intraOpThreads() != null) b.transformerIntraOpThreads(onnxModelOptions.intraOpThreads());
+ if (onnxModelOptions.gpuDevice() != null) b.transformerGpuDevice(onnxModelOptions.gpuDevice().deviceNumber());
+ if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy()));
}
+
}
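
After this change, all three embedders follow the same shape: parse the embedder's child XML elements once into an OnnxModelOptions instance, then copy each non-null value into the component's config builder. The repeated null guard could, hypothetically, be expressed with a small helper; this sketch is not part of the commit and its name is an assumption:

    // Hypothetical helper (not in this commit) capturing the null-guarded mapping
    // repeated in each getConfig(); the name ifPresent is illustrative only.
    private static <T> void ifPresent(T value, java.util.function.Consumer<T> setter) {
        if (value != null) setter.accept(value);
    }

    // Example use inside getConfig(Builder b):
    //     ifPresent(onnxModelOptions.maxTokens(), b::transformerMaxTokens);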
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java
new file mode 100644
index 00000000000..6b573e870fa
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java
@@ -0,0 +1,84 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+
+/**
+ * @author hmusum
+ */
+public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens,
+ String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds,
+ String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads,
+ Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy,
+ Integer transformerStartSequenceToken, Integer transformerEndSequenceToken,
+ Integer maxQueryTokens, Integer maxDocumentTokens, Integer transformerMaskToken) {
+
+
+ public OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens,
+ String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds,
+ String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads,
+ Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy) {
+ this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds,
+ transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
+ null, null, null, null, null);
+ }
+
+ public OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens,
+ String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds,
+ String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads,
+ Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy,
+ Integer transformerStartSequenceToken, Integer transformerEndSequenceToken) {
+ this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds,
+ transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
+ transformerStartSequenceToken, transformerEndSequenceToken, null, null, null);
+ }
+
+ public static OnnxModelOptions empty() {
+ return new OnnxModelOptions(null, null, null, null, null, null, null, null, null, null, null, null, null,
+ null, null, null, null, null);
+ }
+
+ public OnnxModelOptions withExecutionMode(String executionMode) {
+ return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
+ transformerTokenTypeIds, transformerOutput, normalize, executionMode,
+ interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
+ transformerStartSequenceToken, transformerEndSequenceToken,
+ maxQueryTokens, maxDocumentTokens, transformerMaskToken);
+ }
+
+ public OnnxModelOptions withInteropThreads(Integer interopThreads) {
+ return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
+ transformerTokenTypeIds, transformerOutput, normalize, executionMode,
+ interopThreads, intraOpThreads, gpuDevice, poolingStrategy,
+ transformerStartSequenceToken, transformerEndSequenceToken,
+ maxQueryTokens, maxDocumentTokens, transformerMaskToken);
+ }
+
+ public OnnxModelOptions withIntraopThreads(Integer intraopThreads) {
+ return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
+ transformerTokenTypeIds, transformerOutput, normalize, executionMode,
+ interOpThreads, intraopThreads, gpuDevice, poolingStrategy,
+ transformerStartSequenceToken, transformerEndSequenceToken,
+ maxQueryTokens, maxDocumentTokens, transformerMaskToken);
+ }
+
+
+ public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) {
+ return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
+ transformerTokenTypeIds, transformerOutput, normalize, executionMode,
+ interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
+ transformerStartSequenceToken, transformerEndSequenceToken,
+ maxQueryTokens, maxDocumentTokens, transformerMaskToken);
+ }
+
+ public record GpuDevice(int deviceNumber, boolean required) {
+ public GpuDevice {
+ if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
+ }
+
+ public GpuDevice(int deviceNumber) {
+ this(deviceNumber, false);
+ }
+ }
+
+}
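
For reference, a minimal usage sketch of the new record, mirroring how OnnxModel.java above builds up stateless runtime options; the concrete option values are illustrative, not defaults:

    // Start from an empty options record and layer on the values that were set.
    OnnxModelOptions options = OnnxModelOptions.empty()
            .withExecutionMode("parallel")
            .withInteropThreads(2)
            .withIntraopThreads(4)
            .withGpuDevice(new OnnxModelOptions.GpuDevice(0, false));
    // Accessors return null for anything not set, so callers wrap them, e.g.
    // Optional.ofNullable(options.executionMode()).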