about | summary | refs | log | tree | commit | diff | stats
path: root/config-model/src/main/java/com/yahoo/vespa/model/container/component
diff options
context:
space:
mode:
author Jo Kristian Bergum <bergum@yahooinc.com> 2024-01-06 10:54:58 +0100
committer Jo Kristian Bergum <bergum@yahooinc.com> 2024-01-06 10:54:58 +0100
commit 18ae21bce56e018cef2c17d03e63617530af59ae (patch)
tree 3c1dcee63395fee2e476be9ce33e2437262b00d7 /config-model/src/main/java/com/yahoo/vespa/model/container/component
parent e4da75db4556a3cd72b034c4406027f9bba73918 (diff)
handle multilingual models better
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java 20
1 files changed, 16 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index d22e6afc3d1..fcd37150fe4 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -29,10 +29,16 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
private final Integer transformerStartSequenceToken;
private final Integer transformerEndSequenceToken;
private final Integer transformerMaskToken;
+
+ private final Integer transformerPadToken;
private final Integer maxTokens;
private final String transformerInputIds;
private final String transformerAttentionMask;
+ private final Integer queryTokenId;
+
+ private final Integer documentTokenId;
+
private final String transformerOutput;
public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
@@ -53,6 +59,9 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null);
+ transformerPadToken = getChildValue(xml, "transformer-pad-token").map(Integer::parseInt).orElse(null);
+ queryTokenId = getChildValue(xml, "query-token-id").map(Integer::parseInt).orElse(null);
+ documentTokenId = getChildValue(xml, "document-token-id").map(Integer::parseInt).orElse(null);
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
@@ -79,10 +88,13 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken);
if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken);
- onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
- onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
- onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
- onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
+ if (transformerPadToken != null) b.transformerPadToken(transformerPadToken);
+ if (queryTokenId != null) b.queryTokenId(queryTokenId);
+ if (documentTokenId != null) b.documentTokenId(documentTokenId);
+ onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
+ onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
+ onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
+ onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
}
}