diff options
author | Harald Musum <musum@yahooinc.com> | 2023-11-16 09:27:02 +0100 |
---|---|---|
committer | Harald Musum <musum@yahooinc.com> | 2023-11-16 09:27:02 +0100 |
commit | ae5b333a133add3fc6d7ab6c056379c5b4f1564d (patch) | |
tree | 094bc2c57b5fa64a8c7850b793afcf3a8dd05750 /config-model | |
parent | 8c93087d6df4e995e61f17a92ab462e135608225 (diff) | |
Add OnnxModelOptions and use it in embedders
Diffstat (limited to 'config-model')
5 files changed, 201 insertions, 157 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index f3f09150c1d..b22e737fd6a 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -3,6 +3,7 @@ package com.yahoo.schema; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.model.container.component.OnnxModelOptions; import com.yahoo.vespa.model.ml.OnnxModelInfo; import java.util.Collections; @@ -27,10 +28,7 @@ public class OnnxModel extends DistributableResource implements Cloneable { private final Set<String> initializers = new HashSet<>(); // Runtime options - private String statelessExecutionMode = null; - private Integer statelessInterOpThreads = null; - private Integer statelessIntraOpThreads = null; - private GpuDevice gpuDevice = null; + private OnnxModelOptions onnxModelOptions = OnnxModelOptions.empty(); public OnnxModel(String name) { super(name); @@ -133,50 +131,44 @@ public class OnnxModel extends DistributableResource implements Cloneable { public void setStatelessExecutionMode(String executionMode) { if ("parallel".equalsIgnoreCase(executionMode)) { - this.statelessExecutionMode = "parallel"; + onnxModelOptions = onnxModelOptions.withExecutionMode("parallel"); } else if ("sequential".equalsIgnoreCase(executionMode)) { - this.statelessExecutionMode = "sequential"; + onnxModelOptions = onnxModelOptions.withExecutionMode("sequential"); } } public Optional<String> getStatelessExecutionMode() { - return Optional.ofNullable(statelessExecutionMode); + return Optional.ofNullable(onnxModelOptions.executionMode()); } public void setStatelessInterOpThreads(int interOpThreads) { if (interOpThreads >= 0) { - this.statelessInterOpThreads = interOpThreads; + onnxModelOptions = onnxModelOptions.withInteropThreads(interOpThreads); } } public Optional<Integer> 
getStatelessInterOpThreads() { - return Optional.ofNullable(statelessInterOpThreads); + return Optional.ofNullable(onnxModelOptions.interOpThreads()); } public void setStatelessIntraOpThreads(int intraOpThreads) { if (intraOpThreads >= 0) { - this.statelessIntraOpThreads = intraOpThreads; + onnxModelOptions = onnxModelOptions.withIntraopThreads(intraOpThreads); } } public Optional<Integer> getStatelessIntraOpThreads() { - return Optional.ofNullable(statelessIntraOpThreads); + return Optional.ofNullable(onnxModelOptions.intraOpThreads()); } public void setGpuDevice(int deviceNumber, boolean required) { if (deviceNumber >= 0) { - this.gpuDevice = new GpuDevice(deviceNumber, required); + onnxModelOptions = onnxModelOptions.withGpuDevice(new OnnxModelOptions.GpuDevice(deviceNumber, required)); } } - public Optional<GpuDevice> getGpuDevice() { - return Optional.ofNullable(gpuDevice); - } - - public record GpuDevice(int deviceNumber, boolean required) { - public GpuDevice { - if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); - } + public Optional<OnnxModelOptions.GpuDevice> getGpuDevice() { + return Optional.ofNullable(onnxModelOptions.gpuDevice()); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index a644382625b..67edb40b4d3 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -1,13 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
- package com.yahoo.vespa.model.container.component; -import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode; +import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -16,56 +16,45 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI */ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer { - private final ModelReference modelRef; - private final ModelReference vocabRef; - private final Integer maxTokens; - private final String transformerInputIds; - private final String transformerAttentionMask; - private final String transformerTokenTypeIds; - private final String transformerOutput; - private final Integer tranformerStartSequenceToken; - private final Integer transformerEndSequenceToken; - private final String poolingStrategy; - private final String onnxExecutionMode; - private final Integer onnxInteropThreads; - private final Integer onnxIntraopThreads; - private final Integer onnxGpuDevice; - + private final OnnxModelOptions onnxModelOptions; public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); - modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(); - maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); - transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); - 
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); - transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null); - transformerOutput = getChildValue(xml, "transformer-output").orElse(null); - tranformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); - transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); - poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null); - onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); - onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); - onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + this.onnxModelOptions = new OnnxModelOptions( + model.modelReference(), + Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(), + getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null), + getChildValue(xml, "transformer-input-ids").orElse(null), + getChildValue(xml, "transformer-attention-mask").orElse(null), + getChildValue(xml, "transformer-token-type-ids").orElse(null), + getChildValue(xml, "transformer-output").orElse(null), + getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null), + getChildValue(xml, "onnx-execution-mode").orElse(null), + getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null), + getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null), + getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null), + getChildValue(xml, "pooling-strategy").orElse(null), + getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null), + 
getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null)); model.registerOnnxModelCost(cluster); } @Override public void getConfig(BertBaseEmbedderConfig.Builder b) { - b.transformerModel(modelRef).tokenizerVocab(vocabRef); - if (maxTokens != null) b.transformerMaxTokens(maxTokens); - if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); - if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); - if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds); - if (transformerOutput != null) b.transformerOutput(transformerOutput); - if (tranformerStartSequenceToken != null) b.transformerStartSequenceToken(tranformerStartSequenceToken); - if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); - if (poolingStrategy != null) b.poolingStrategy(BertBaseEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy)); - if (onnxExecutionMode != null) b.onnxExecutionMode(BertBaseEmbedderConfig.OnnxExecutionMode.Enum.valueOf(onnxExecutionMode)); - if (onnxInteropThreads != null) b.onnxInterOpThreads(onnxInteropThreads); - if (onnxIntraopThreads != null) b.onnxIntraOpThreads(onnxIntraopThreads); - if (onnxGpuDevice != null) b.onnxGpuDevice(onnxGpuDevice); + b.transformerModel(onnxModelOptions.modelRef()).tokenizerVocab(onnxModelOptions.vocabRef()); + if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens()); + if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds()); + if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask()); + if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds()); + if (onnxModelOptions.transformerOutput() != null) 
b.transformerOutput(onnxModelOptions.transformerOutput()); + if (onnxModelOptions.transformerStartSequenceToken() != null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken()); + if (onnxModelOptions.transformerEndSequenceToken() != null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken()); + if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy())); + if (onnxModelOptions.executionMode() != null) b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(onnxModelOptions.executionMode())); + if (onnxModelOptions.interOpThreads() != null) b.onnxInterOpThreads(onnxModelOptions.interOpThreads()); + if (onnxModelOptions.intraOpThreads() != null) b.onnxIntraOpThreads(onnxModelOptions.intraOpThreads()); + if (onnxModelOptions.gpuDevice() != null) b.onnxGpuDevice(onnxModelOptions.gpuDevice().deviceNumber()); } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index ed56579988d..5192c806797 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -1,5 +1,4 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
- package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; @@ -8,6 +7,7 @@ import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -16,46 +16,33 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI * @author bergum */ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer { - private final ModelReference modelRef; - private final ModelReference vocabRef; - - private final Integer maxQueryTokens; - - private final Integer maxDocumentTokens; - private final Integer transformerStartSequenceToken; - private final Integer transformerEndSequenceToken; - private final Integer transformerMaskToken; - private final Integer maxTokens; - private final String transformerInputIds; - private final String transformerAttentionMask; - - private final String transformerOutput; - private final String onnxExecutionMode; - private final Integer onnxInteropThreads; - private final Integer onnxIntraopThreads; - private final Integer onnxGpuDevice; + private final OnnxModelOptions onnxModelOptions; public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); - modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-model") - .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state)); - maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); - maxQueryTokens = getChildValue(xml, 
"max-query-tokens").map(Integer::parseInt).orElse(null); - maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null); - transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); - transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); - transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null); - transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); - transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); - transformerOutput = getChildValue(xml, "transformer-output").orElse(null); - onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null); - onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); - onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); - onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + this.onnxModelOptions = new OnnxModelOptions( + model.modelReference(), + Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)), + getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null), + getChildValue(xml, "transformer-input-ids").orElse(null), + getChildValue(xml, "transformer-attention-mask").orElse(null), + getChildValue(xml, "transformer-token-type-ids").orElse(null), + getChildValue(xml, "transformer-output").orElse(null), + getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null), + getChildValue(xml, "onnx-execution-mode").orElse(null), + getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null), + getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null), + getChildValue(xml, 
"onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null), + getChildValue(xml, "pooling-strategy").orElse(null), + getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null), + getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null), + getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null), + getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null), + getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null)); model.registerOnnxModelCost(cluster); } @@ -69,20 +56,20 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo @Override public void getConfig(ColBertEmbedderConfig.Builder b) { - b.transformerModel(modelRef).tokenizerPath(vocabRef); - if (maxTokens != null) b.transformerMaxTokens(maxTokens); - if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); - if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); - if (transformerOutput != null) b.transformerOutput(transformerOutput); - if (maxQueryTokens != null) b.maxQueryTokens(maxQueryTokens); - if (maxDocumentTokens != null) b.maxDocumentTokens(maxDocumentTokens); - if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken); - if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); - if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken); - if (onnxExecutionMode != null) b.transformerExecutionMode( - ColBertEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode)); - if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads); - if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads); - if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice); + 
b.transformerModel(onnxModelOptions.modelRef()).tokenizerPath(onnxModelOptions.vocabRef()); + if (onnxModelOptions.maxTokens() !=null) b.transformerMaxTokens(onnxModelOptions.maxTokens()); + if (onnxModelOptions.transformerInputIds() !=null) b.transformerInputIds(onnxModelOptions.transformerInputIds()); + if (onnxModelOptions.transformerAttentionMask() !=null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask()); + if (onnxModelOptions.transformerOutput() !=null) b.transformerOutput(onnxModelOptions.transformerOutput()); + if (onnxModelOptions.maxQueryTokens() !=null) b.maxQueryTokens(onnxModelOptions.maxQueryTokens()); + if (onnxModelOptions.maxDocumentTokens() !=null) b.maxDocumentTokens(onnxModelOptions.maxDocumentTokens()); + if (onnxModelOptions.transformerStartSequenceToken() !=null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken()); + if (onnxModelOptions.transformerEndSequenceToken() !=null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken()); + if (onnxModelOptions.transformerMaskToken() !=null) b.transformerMaskToken(onnxModelOptions.transformerMaskToken()); + if (onnxModelOptions.executionMode() !=null) b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(onnxModelOptions.executionMode())); + if (onnxModelOptions.interOpThreads() !=null) b.transformerInterOpThreads(onnxModelOptions.interOpThreads()); + if (onnxModelOptions.intraOpThreads() !=null) b.transformerIntraOpThreads(onnxModelOptions.intraOpThreads()); + if (onnxModelOptions.gpuDevice() !=null) b.transformerGpuDevice(onnxModelOptions.gpuDevice().deviceNumber()); } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index 31b86142445..5bf30174539 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ 
b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -1,5 +1,4 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; @@ -8,6 +7,8 @@ import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy; +import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -16,38 +17,29 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI * @author bjorncs */ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer { - private final ModelReference modelRef; - private final ModelReference vocabRef; - private final Integer maxTokens; - private final String transformerInputIds; - private final String transformerAttentionMask; - private final String transformerTokenTypeIds; - private final String transformerOutput; - private final Boolean normalize; - private final String onnxExecutionMode; - private final Integer onnxInteropThreads; - private final Integer onnxIntraopThreads; - private final Integer onnxGpuDevice; - private final String poolingStrategy; + + private final OnnxModelOptions onnxModelOptions; public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); - modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, 
"tokenizer-model") - .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state)); - maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); - transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); - transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); - transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null); - transformerOutput = getChildValue(xml, "transformer-output").orElse(null); - normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); - onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null); - onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); - onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); - onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); - poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); + var modelRef = model.modelReference(); + this.onnxModelOptions = new OnnxModelOptions( + modelRef, + Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)), + getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null), + getChildValue(xml, "transformer-input-ids").orElse(null), + getChildValue(xml, "transformer-attention-mask").orElse(null), + getChildValue(xml, "transformer-token-type-ids").orElse(null), + getChildValue(xml, "transformer-output").orElse(null), + getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null), + getChildValue(xml, "onnx-execution-mode").orElse(null), + getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null), + getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null), + getChildValue(xml, 
"onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null), + getChildValue(xml, "pooling-strategy").orElse(null)); model.registerOnnxModelCost(cluster); } @@ -61,18 +53,18 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm @Override public void getConfig(HuggingFaceEmbedderConfig.Builder b) { - b.transformerModel(modelRef).tokenizerPath(vocabRef); - if (maxTokens != null) b.transformerMaxTokens(maxTokens); - if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); - if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); - if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds); - if (transformerOutput != null) b.transformerOutput(transformerOutput); - if (normalize != null) b.normalize(normalize); - if (onnxExecutionMode != null) b.transformerExecutionMode( - HuggingFaceEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode)); - if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads); - if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads); - if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice); - if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy)); + b.transformerModel(onnxModelOptions.modelRef()).tokenizerPath(onnxModelOptions.vocabRef()); + if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens()); + if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds()); + if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask()); + if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds()); + if (onnxModelOptions.transformerOutput() != null) 
b.transformerOutput(onnxModelOptions.transformerOutput()); + if (onnxModelOptions.normalize() != null) b.normalize(onnxModelOptions.normalize()); + if (onnxModelOptions.executionMode() != null) b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(onnxModelOptions.executionMode())); + if (onnxModelOptions.interOpThreads() != null) b.transformerInterOpThreads(onnxModelOptions.interOpThreads()); + if (onnxModelOptions.intraOpThreads() != null) b.transformerIntraOpThreads(onnxModelOptions.intraOpThreads()); + if (onnxModelOptions.gpuDevice() != null) b.transformerGpuDevice(onnxModelOptions.gpuDevice().deviceNumber()); + if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy())); } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java new file mode 100644 index 00000000000..6b573e870fa --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java @@ -0,0 +1,84 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.model.container.component; + +import com.yahoo.config.ModelReference; + +/** + * @author hmusum + */ +public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens, + String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds, + String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads, + Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy, + Integer transformerStartSequenceToken, Integer transformerEndSequenceToken, + Integer maxQueryTokens, Integer maxDocumentTokens, Integer transformerMaskToken) { + + + public OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens, + String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds, + String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads, + Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy) { + this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds, + transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, + null, null, null, null, null); + } + + public OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens, + String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds, + String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads, + Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy, + Integer transformerStartSequenceToken, Integer transformerEndSequenceToken) { + this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds, + transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, + transformerStartSequenceToken, transformerEndSequenceToken, null, null, null); + 
} + + public static OnnxModelOptions empty() { + return new OnnxModelOptions(null, null, null, null, null, null, null, null, null, null, null, null, null, + null, null, null, null, null); + } + + public OnnxModelOptions withExecutionMode(String executionMode) { + return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, + transformerTokenTypeIds, transformerOutput, normalize, executionMode, + interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, + transformerStartSequenceToken, transformerEndSequenceToken, + maxQueryTokens, maxDocumentTokens, transformerMaskToken); + } + + public OnnxModelOptions withInteropThreads(Integer interopThreads) { + return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, + transformerTokenTypeIds, transformerOutput, normalize, executionMode, + interopThreads, intraOpThreads, gpuDevice, poolingStrategy, + transformerStartSequenceToken, transformerEndSequenceToken, + maxQueryTokens, maxDocumentTokens, transformerMaskToken); + } + + public OnnxModelOptions withIntraopThreads(Integer intraopThreads) { + return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, + transformerTokenTypeIds, transformerOutput, normalize, executionMode, + interOpThreads, intraopThreads, gpuDevice, poolingStrategy, + transformerStartSequenceToken, transformerEndSequenceToken, + maxQueryTokens, maxDocumentTokens, transformerMaskToken); + } + + + public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) { + return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, + transformerTokenTypeIds, transformerOutput, normalize, executionMode, + interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, + transformerStartSequenceToken, transformerEndSequenceToken, + maxQueryTokens, maxDocumentTokens, transformerMaskToken); + } + + public record GpuDevice(int deviceNumber, 
boolean required) { + public GpuDevice { + if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); + } + + public GpuDevice(int deviceNumber) { + this(deviceNumber, false); + } + } + +} |