diff options
author | Harald Musum <musum@yahooinc.com> | 2023-11-16 11:19:28 +0100 |
---|---|---|
committer | Harald Musum <musum@yahooinc.com> | 2023-11-16 11:19:28 +0100 |
commit | 2d1ba9b48bdad4c1af7058a37821e447c89bb420 (patch) | |
tree | e2d912a1a0a174834196e967b420440f2d7f32c7 /config-model/src/main/java/com | |
parent | d20e4909d59116b231c4b25fba35f22e673e407b (diff) |
Include only options that control how an Onnx model is loaded in OnnxModelOptions
Diffstat (limited to 'config-model/src/main/java/com')
5 files changed, 114 insertions, 141 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 0d2c695280e..867ffdb3960 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -143,7 +143,7 @@ public class OnnxModel extends DistributableResource implements Cloneable { public void setStatelessInterOpThreads(int interOpThreads) { if (interOpThreads >= 0) { - onnxModelOptions = onnxModelOptions.withInteropThreads(interOpThreads); + onnxModelOptions = onnxModelOptions.withInterOpThreads(interOpThreads); } } @@ -153,7 +153,7 @@ public class OnnxModel extends DistributableResource implements Cloneable { public void setStatelessIntraOpThreads(int intraOpThreads) { if (intraOpThreads >= 0) { - onnxModelOptions = onnxModelOptions.withIntraopThreads(intraOpThreads); + onnxModelOptions = onnxModelOptions.withIntraOpThreads(intraOpThreads); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index 78177fd7d57..ea3caadc23a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -1,13 +1,12 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container.component; +import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; -import java.util.Optional; - import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode; import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy; import static com.yahoo.text.XML.getChildValue; @@ -19,40 +18,49 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer { private final OnnxModelOptions onnxModelOptions; + private final ModelReference modelRef; + private final ModelReference vocabRef; + private final Integer maxTokens; + private final String transformerInputIds; + private final String transformerAttentionMask; + private final String transformerTokenTypeIds; + private final String transformerOutput; + private final Integer transformerStartSequenceToken; + private final Integer transformerEndSequenceToken; + private final String poolingStrategy; public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( - Optional.of(model.modelReference()), - Optional.of(Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference()), - getChildValue(xml, "max-tokens").map(Integer::parseInt), - getChildValue(xml, "transformer-input-ids"), - getChildValue(xml, "transformer-attention-mask"), - getChildValue(xml, "transformer-token-type-ids"), - getChildValue(xml, "transformer-output"), - getChildValue(xml, "normalize").map(Boolean::parseBoolean), getChildValue(xml, "onnx-execution-mode"), getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), - getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new), - getChildValue(xml, "pooling-strategy"), - getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt), - getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt)); + getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(); + maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); + transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); + transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); + transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null); + transformerOutput = getChildValue(xml, "transformer-output").orElse(null); + transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); + transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); + poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); model.registerOnnxModelCost(cluster); } @Override public void getConfig(BertBaseEmbedderConfig.Builder b) { - b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerVocab(onnxModelOptions.vocabRef().get()); - onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens); - onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds); - onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask); - onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds); - onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput); - onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken); - onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken); - onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value))); + b.transformerModel(modelRef).tokenizerVocab(vocabRef); + if (maxTokens != null) b.transformerMaxTokens(maxTokens); + if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); + if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); + if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds); + if (transformerOutput != null) b.transformerOutput(transformerOutput); + if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken); + if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); + if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy)); onnxModelOptions.executionMode().ifPresent(value -> b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(value))); onnxModelOptions.interOpThreads().ifPresent(b::onnxInterOpThreads); onnxModelOptions.intraOpThreads().ifPresent(b::onnxIntraOpThreads); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index 33753094f69..cbae50b400c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -7,8 +7,6 @@ import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; -import java.util.Optional; - import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -20,31 +18,43 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer { private final OnnxModelOptions onnxModelOptions; + private final ModelReference modelRef; + private final ModelReference vocabRef; + + private final Integer maxQueryTokens; + + private final Integer maxDocumentTokens; + + private final Integer transformerStartSequenceToken; + private final Integer transformerEndSequenceToken; + private final Integer transformerMaskToken; + private final Integer maxTokens; + private final String transformerInputIds; + private final String transformerAttentionMask; + + private final String transformerOutput; public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( - Optional.of(model.modelReference()), - Optional.of(Model.fromXml(state, xml, "tokenizer-model") - .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state))), - getChildValue(xml, "max-tokens").map(Integer::parseInt), - getChildValue(xml, "transformer-input-ids"), - getChildValue(xml, "transformer-attention-mask"), - getChildValue(xml, "transformer-token-type-ids"), - getChildValue(xml, "transformer-output"), - getChildValue(xml, "normalize").map(Boolean::parseBoolean), getChildValue(xml, "onnx-execution-mode"), getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), - getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new), - getChildValue(xml, "pooling-strategy"), - getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt), - getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt), - getChildValue(xml, "max-query-tokens").map(Integer::parseInt), - getChildValue(xml, "max-document-tokens").map(Integer::parseInt), - getChildValue(xml, "transformer-mask-token").map(Integer::parseInt)); + getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)); + maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); + maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null); + maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null); + transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); + transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); + transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null); + transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); + transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); + transformerOutput = getChildValue(xml, "transformer-output").orElse(null); model.registerOnnxModelCost(cluster); } @@ -58,16 +68,16 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo @Override public void getConfig(ColBertEmbedderConfig.Builder b) { - b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerPath(onnxModelOptions.vocabRef().get()); - onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens); - onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds); - onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask); - onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput); - onnxModelOptions.maxQueryTokens().ifPresent(b::maxQueryTokens); - onnxModelOptions.maxDocumentTokens().ifPresent(b::maxDocumentTokens); - onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken); - onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken); - onnxModelOptions.transformerMaskToken().ifPresent(b::transformerMaskToken); + b.transformerModel(modelRef).tokenizerPath(vocabRef); + if (maxTokens != null) b.transformerMaxTokens(maxTokens); + if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); + if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); + if (transformerOutput != null) b.transformerOutput(transformerOutput); + if (maxQueryTokens != null) b.maxQueryTokens(maxQueryTokens); + if (maxDocumentTokens != null) b.maxDocumentTokens(maxDocumentTokens); + if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken); + if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); + if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken); onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index af6febeb1e5..d1bd0dce000 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -7,8 +7,6 @@ import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; -import java.util.Optional; - import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; @@ -21,27 +19,35 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer { private final OnnxModelOptions onnxModelOptions; + private final ModelReference modelRef; + private final ModelReference vocabRef; + private final Integer maxTokens; + private final String transformerInputIds; + private final String transformerAttentionMask; + private final String transformerTokenTypeIds; + private final String transformerOutput; + private final Boolean normalize; + private final String poolingStrategy; public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); - var modelRef = model.modelReference(); this.onnxModelOptions = new OnnxModelOptions( - Optional.of(modelRef), - Optional.of(Model.fromXml(state, xml, "tokenizer-model") - .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state))), - getChildValue(xml, "max-tokens").map(Integer::parseInt), - getChildValue(xml, "transformer-input-ids"), - getChildValue(xml, "transformer-attention-mask"), - getChildValue(xml, "transformer-token-type-ids"), - getChildValue(xml, "transformer-output"), - getChildValue(xml, "normalize").map(Boolean::parseBoolean), getChildValue(xml, "onnx-execution-mode"), getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), - getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new), - getChildValue(xml, "pooling-strategy")); + getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)); + maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); + transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); + transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); + transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null); + transformerOutput = getChildValue(xml, "transformer-output").orElse(null); + normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); + poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); model.registerOnnxModelCost(cluster); } @@ -55,18 +61,18 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm @Override public void getConfig(HuggingFaceEmbedderConfig.Builder b) { - b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerPath(onnxModelOptions.vocabRef().get()); - onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens); - onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds); - onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask); - onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds); - onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput); - onnxModelOptions.normalize().ifPresent(b::normalize); + b.transformerModel(modelRef).tokenizerPath(vocabRef); + if (maxTokens != null) b.transformerMaxTokens(maxTokens); + if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); + if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); + if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds); + if (transformerOutput != null) b.transformerOutput(transformerOutput); + if (normalize != null) b.normalize(normalize); + if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy)); onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); - onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value))); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java index 0efb8a71fe4..6347f0dc427 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java @@ -1,86 +1,35 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container.component; -import com.yahoo.config.ModelReference; - import java.util.Optional; /** + * Onnx model options that are relevant when deciding if an Onnx model needs to be reloaded. If any of the + * values in this class change, reload is needed. + * * @author hmusum */ -public record OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef, - Optional<Integer> maxTokens, Optional<String> transformerInputIds, - Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds, - Optional<String> transformerOutput, Optional<Boolean> normalize, - Optional<String> executionMode, Optional<Integer> interOpThreads, - Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice, - Optional<String> poolingStrategy, Optional<Integer> transformerStartSequenceToken, - Optional<Integer> transformerEndSequenceToken, Optional<Integer> maxQueryTokens, - Optional<Integer> maxDocumentTokens, Optional<Integer> transformerMaskToken) { - - - public OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef, - Optional<Integer> maxTokens, Optional<String> transformerInputIds, - Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds, - Optional<String> transformerOutput, Optional<Boolean> normalize, - Optional<String> executionMode, Optional<Integer> interOpThreads, - Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice, - Optional<String> poolingStrategy) { - this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds, - transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, - Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - } +public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads, + Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) { - public OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef, - Optional<Integer> maxTokens, Optional<String> transformerInputIds, - Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds, - Optional<String> transformerOutput, Optional<Boolean> normalize, - Optional<String> executionMode, Optional<Integer> interOpThreads, - Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice, - Optional<String> poolingStrategy, Optional<Integer> transformerStartSequenceToken, - Optional<Integer> transformerEndSequenceToken) { - this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds, - transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, - transformerStartSequenceToken, transformerEndSequenceToken, Optional.empty(), Optional.empty(), Optional.empty()); - } - - public static OnnxModelOptions empty() { - return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), - Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), - Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + public static OnnxModelOptions empty() { + return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); } public OnnxModelOptions withExecutionMode(String executionMode) { - return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, - transformerTokenTypeIds, transformerOutput, normalize, Optional.ofNullable(executionMode), - interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, - transformerStartSequenceToken, transformerEndSequenceToken, - maxQueryTokens, maxDocumentTokens, transformerMaskToken); + return new OnnxModelOptions(Optional.ofNullable(executionMode), interOpThreads, intraOpThreads, gpuDevice); } - public OnnxModelOptions withInteropThreads(Integer interopThreads) { - return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, - transformerTokenTypeIds, transformerOutput, normalize, executionMode, - Optional.ofNullable(interopThreads), intraOpThreads, gpuDevice, poolingStrategy, - transformerStartSequenceToken, transformerEndSequenceToken, - maxQueryTokens, maxDocumentTokens, transformerMaskToken); + public OnnxModelOptions withInterOpThreads(Integer interOpThreads) { + return new OnnxModelOptions(executionMode, Optional.ofNullable(interOpThreads), intraOpThreads, gpuDevice); } - public OnnxModelOptions withIntraopThreads(Integer intraopThreads) { - return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, - transformerTokenTypeIds, transformerOutput, normalize, executionMode, - interOpThreads, Optional.ofNullable(intraopThreads), gpuDevice, poolingStrategy, - transformerStartSequenceToken, transformerEndSequenceToken, - maxQueryTokens, maxDocumentTokens, transformerMaskToken); + public OnnxModelOptions withIntraOpThreads(Integer intraOpThreads) { + return new OnnxModelOptions(executionMode, interOpThreads, Optional.ofNullable(intraOpThreads), gpuDevice); } - public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) { - return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, - transformerTokenTypeIds, transformerOutput, normalize, executionMode, - interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice), poolingStrategy, - transformerStartSequenceToken, transformerEndSequenceToken, - maxQueryTokens, maxDocumentTokens, transformerMaskToken); + return new OnnxModelOptions(executionMode, interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice)); } public record GpuDevice(int deviceNumber, boolean required) { |