diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-01-06 10:54:58 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-01-06 10:54:58 +0100 |
commit | 18ae21bce56e018cef2c17d03e63617530af59ae (patch) | |
tree | 3c1dcee63395fee2e476be9ce33e2437262b00d7 /config-model/src/main/java/com/yahoo/vespa/model/container/component | |
parent | e4da75db4556a3cd72b034c4406027f9bba73918 (diff) |
handle multilingual models better
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index d22e6afc3d1..fcd37150fe4 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -29,10 +29,16 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo private final Integer transformerStartSequenceToken; private final Integer transformerEndSequenceToken; private final Integer transformerMaskToken; + + private final Integer transformerPadToken; private final Integer maxTokens; private final String transformerInputIds; private final String transformerAttentionMask; + private final Integer queryTokenId; + + private final Integer documentTokenId; + private final String transformerOutput; public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { @@ -53,6 +59,9 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null); + transformerPadToken = getChildValue(xml, "transformer-pad-token").map(Integer::parseInt).orElse(null); + queryTokenId = getChildValue(xml, "query-token-id").map(Integer::parseInt).orElse(null); + documentTokenId = getChildValue(xml, "document-token-id").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); transformerOutput = getChildValue(xml, "transformer-output").orElse(null); @@ -79,10 +88,13 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken); if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken); - onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); - onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); - onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); - onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); + if (transformerPadToken != null) b.transformerPadToken(transformerPadToken); + if (queryTokenId != null) b.queryTokenId(queryTokenId); + if (documentTokenId != null) b.documentTokenId(documentTokenId); + onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); + onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); + onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); + onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); } } |