diff options
author | Jo Kristian Bergum <bergum@yahoo-inc.com> | 2024-01-10 13:28:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-10 13:28:47 +0100 |
commit | c4e33003e5ce3f385951c107714ede8556ef8083 (patch) | |
tree | cd4b84ac9e3840384ad0a7c2b6c85de0a01a7102 /config-model/src/main/java/com/yahoo | |
parent | 949cede5ec0375c03dacdbb141f04e471aac8099 (diff) | |
parent | 2f3a69daf2f212aaa3ed29c89407d4af95b65138 (diff) |
Merge pull request #29826 from vespa-engine/jobergum/colbert-handle-multilingual-tokenizers
colbert handle multilingual tokenizers better
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java | 20 |
1 file changed, 16 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index abca3290a31..aa8f97784e1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -33,10 +33,16 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo private final Integer transformerStartSequenceToken; private final Integer transformerEndSequenceToken; private final Integer transformerMaskToken; + + private final Integer transformerPadToken; private final Integer maxTokens; private final String transformerInputIds; private final String transformerAttentionMask; + private final Integer queryTokenId; + + private final Integer documentTokenId; + private final String transformerOutput; public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { @@ -55,6 +61,9 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null); + transformerPadToken = getChildValue(xml, "transformer-pad-token").map(Integer::parseInt).orElse(null); + queryTokenId = getChildValue(xml, "query-token-id").map(Integer::parseInt).orElse(null); + documentTokenId = getChildValue(xml, "document-token-id").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); transformerOutput = 
getChildValue(xml, "transformer-output").orElse(null); @@ -73,10 +82,13 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken); if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken); - onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); - onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); - onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); - onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); + if (transformerPadToken != null) b.transformerPadToken(transformerPadToken); + if (queryTokenId != null) b.queryTokenId(queryTokenId); + if (documentTokenId != null) b.documentTokenId(documentTokenId); + onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); + onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); + onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); + onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); } } |