about | summary | refs | log | tree | commit | diff | stats
path: root/config-model/src/main/java/com/yahoo
diff options
context:
space:
mode:
author: Jo Kristian Bergum <bergum@yahoo-inc.com> 2024-01-10 13:28:47 +0100
committer: GitHub <noreply@github.com> 2024-01-10 13:28:47 +0100
commit: c4e33003e5ce3f385951c107714ede8556ef8083 (patch)
tree: cd4b84ac9e3840384ad0a7c2b6c85de0a01a7102 /config-model/src/main/java/com/yahoo
parent: 949cede5ec0375c03dacdbb141f04e471aac8099 (diff)
parent: 2f3a69daf2f212aaa3ed29c89407d4af95b65138 (diff)
Merge pull request #29826 from vespa-engine/jobergum/colbert-handle-multilingual-tokenizers
colbert handle multilingual tokenizers better
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java 20
1 files changed, 16 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index abca3290a31..aa8f97784e1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -33,10 +33,16 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
private final Integer transformerStartSequenceToken;
private final Integer transformerEndSequenceToken;
private final Integer transformerMaskToken;
+
+ private final Integer transformerPadToken;
private final Integer maxTokens;
private final String transformerInputIds;
private final String transformerAttentionMask;
+ private final Integer queryTokenId;
+
+ private final Integer documentTokenId;
+
private final String transformerOutput;
public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
@@ -55,6 +61,9 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null);
+ transformerPadToken = getChildValue(xml, "transformer-pad-token").map(Integer::parseInt).orElse(null);
+ queryTokenId = getChildValue(xml, "query-token-id").map(Integer::parseInt).orElse(null);
+ documentTokenId = getChildValue(xml, "document-token-id").map(Integer::parseInt).orElse(null);
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
@@ -73,10 +82,13 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken);
if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken);
- onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
- onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
- onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
- onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
+ if (transformerPadToken != null) b.transformerPadToken(transformerPadToken);
+ if (queryTokenId != null) b.queryTokenId(queryTokenId);
+ if (documentTokenId != null) b.documentTokenId(documentTokenId);
+ onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
+ onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
+ onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
+ onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
}
}