diff options
author | Harald Musum <musum@yahooinc.com> | 2023-11-16 10:45:20 +0100 |
---|---|---|
committer | Harald Musum <musum@yahooinc.com> | 2023-11-16 10:45:20 +0100 |
commit | d20e4909d59116b231c4b25fba35f22e673e407b (patch) | |
tree | 30447ea12070c1c4fe7f209f0f1fb8fb7208cb18 | |
parent | ae5b333a133add3fc6d7ab6c056379c5b4f1564d (diff) |
Use Optionals in OnnxModelOptions
5 files changed, 132 insertions, 114 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index b22e737fd6a..0d2c695280e 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -138,7 +138,7 @@ public class OnnxModel extends DistributableResource implements Cloneable { } public Optional<String> getStatelessExecutionMode() { - return Optional.ofNullable(onnxModelOptions.executionMode()); + return onnxModelOptions.executionMode(); } public void setStatelessInterOpThreads(int interOpThreads) { @@ -148,7 +148,7 @@ public class OnnxModel extends DistributableResource implements Cloneable { } public Optional<Integer> getStatelessInterOpThreads() { - return Optional.ofNullable(onnxModelOptions.interOpThreads()); + return onnxModelOptions.interOpThreads(); } public void setStatelessIntraOpThreads(int intraOpThreads) { @@ -158,7 +158,7 @@ public class OnnxModel extends DistributableResource implements Cloneable { } public Optional<Integer> getStatelessIntraOpThreads() { - return Optional.ofNullable(onnxModelOptions.intraOpThreads()); + return onnxModelOptions.intraOpThreads(); } public void setGpuDevice(int deviceNumber, boolean required) { @@ -168,7 +168,7 @@ public class OnnxModel extends DistributableResource implements Cloneable { } public Optional<OnnxModelOptions.GpuDevice> getGpuDevice() { - return Optional.ofNullable(onnxModelOptions.gpuDevice()); + return onnxModelOptions.gpuDevice(); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index 67edb40b4d3..78177fd7d57 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -6,6 +6,8 @@ import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import java.util.Optional; + import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode; import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy; import static com.yahoo.text.XML.getChildValue; @@ -22,39 +24,39 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( - model.modelReference(), - Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(), - getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null), - getChildValue(xml, "transformer-input-ids").orElse(null), - getChildValue(xml, "transformer-attention-mask").orElse(null), - getChildValue(xml, "transformer-token-type-ids").orElse(null), - getChildValue(xml, "transformer-output").orElse(null), - getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null), - getChildValue(xml, "onnx-execution-mode").orElse(null), - getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null), - getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null), - getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null), - getChildValue(xml, "pooling-strategy").orElse(null), - getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null), - getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null)); + Optional.of(model.modelReference()), + Optional.of(Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference()), + getChildValue(xml, "max-tokens").map(Integer::parseInt), + getChildValue(xml, "transformer-input-ids"), + getChildValue(xml, "transformer-attention-mask"), + getChildValue(xml, "transformer-token-type-ids"), + getChildValue(xml, "transformer-output"), + getChildValue(xml, "normalize").map(Boolean::parseBoolean), + getChildValue(xml, "onnx-execution-mode"), + getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new), + getChildValue(xml, "pooling-strategy"), + getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt), + getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt)); model.registerOnnxModelCost(cluster); } @Override public void getConfig(BertBaseEmbedderConfig.Builder b) { - b.transformerModel(onnxModelOptions.modelRef()).tokenizerVocab(onnxModelOptions.vocabRef()); - if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens()); - if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds()); - if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask()); - if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds()); - if (onnxModelOptions.transformerOutput() != null) b.transformerOutput(onnxModelOptions.transformerOutput()); - if (onnxModelOptions.transformerStartSequenceToken() != null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken()); - if (onnxModelOptions.transformerEndSequenceToken() != null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken()); - if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy())); - if (onnxModelOptions.executionMode() != null) b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(onnxModelOptions.executionMode())); - if (onnxModelOptions.interOpThreads() != null) b.onnxInterOpThreads(onnxModelOptions.interOpThreads()); - if (onnxModelOptions.intraOpThreads() != null) b.onnxIntraOpThreads(onnxModelOptions.intraOpThreads()); - if (onnxModelOptions.gpuDevice() != null) b.onnxGpuDevice(onnxModelOptions.gpuDevice().deviceNumber()); + b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerVocab(onnxModelOptions.vocabRef().get()); + onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens); + onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds); + onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask); + onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds); + onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput); + onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken); + onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken); + onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value))); + onnxModelOptions.executionMode().ifPresent(value -> b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(value))); + onnxModelOptions.interOpThreads().ifPresent(b::onnxInterOpThreads); + onnxModelOptions.intraOpThreads().ifPresent(b::onnxIntraOpThreads); + onnxModelOptions.gpuDevice().ifPresent(value -> b.onnxGpuDevice(value.deviceNumber())); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index 5192c806797..33753094f69 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -7,6 +7,8 @@ import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import java.util.Optional; + import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -23,26 +25,26 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( - model.modelReference(), - Model.fromXml(state, xml, "tokenizer-model") + Optional.of(model.modelReference()), + Optional.of(Model.fromXml(state, xml, "tokenizer-model") .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state)), - getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null), - getChildValue(xml, "transformer-input-ids").orElse(null), - getChildValue(xml, "transformer-attention-mask").orElse(null), - getChildValue(xml, "transformer-token-type-ids").orElse(null), - getChildValue(xml, "transformer-output").orElse(null), - getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null), - getChildValue(xml, "onnx-execution-mode").orElse(null), - getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null), - getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null), - getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null), - getChildValue(xml, "pooling-strategy").orElse(null), - getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null), - getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null), - getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null), - getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null), - getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null)); + .orElseGet(() -> resolveDefaultVocab(model, state))), + getChildValue(xml, "max-tokens").map(Integer::parseInt), + getChildValue(xml, "transformer-input-ids"), + getChildValue(xml, "transformer-attention-mask"), + getChildValue(xml, "transformer-token-type-ids"), + getChildValue(xml, "transformer-output"), + getChildValue(xml, "normalize").map(Boolean::parseBoolean), + getChildValue(xml, "onnx-execution-mode"), + getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new), + getChildValue(xml, "pooling-strategy"), + getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt), + getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt), + getChildValue(xml, "max-query-tokens").map(Integer::parseInt), + getChildValue(xml, "max-document-tokens").map(Integer::parseInt), + getChildValue(xml, "transformer-mask-token").map(Integer::parseInt)); model.registerOnnxModelCost(cluster); } @@ -56,20 +58,20 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo @Override public void getConfig(ColBertEmbedderConfig.Builder b) { - b.transformerModel(onnxModelOptions.modelRef()).tokenizerPath(onnxModelOptions.vocabRef()); - if (onnxModelOptions.maxTokens() !=null) b.transformerMaxTokens(onnxModelOptions.maxTokens()); - if (onnxModelOptions.transformerInputIds() !=null) b.transformerInputIds(onnxModelOptions.transformerInputIds()); - if (onnxModelOptions.transformerAttentionMask() !=null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask()); - if (onnxModelOptions.transformerOutput() !=null) b.transformerOutput(onnxModelOptions.transformerOutput()); - if (onnxModelOptions.maxQueryTokens() !=null) b.maxQueryTokens(onnxModelOptions.maxQueryTokens()); - if (onnxModelOptions.maxDocumentTokens() !=null) b.maxDocumentTokens(onnxModelOptions.maxDocumentTokens()); - if (onnxModelOptions.transformerStartSequenceToken() !=null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken()); - if (onnxModelOptions.transformerEndSequenceToken() !=null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken()); - if (onnxModelOptions.transformerMaskToken() !=null) b.transformerMaskToken(onnxModelOptions.transformerMaskToken()); - if (onnxModelOptions.executionMode() !=null) b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(onnxModelOptions.executionMode())); - if (onnxModelOptions.interOpThreads() !=null) b.transformerInterOpThreads(onnxModelOptions.interOpThreads()); - if (onnxModelOptions.intraOpThreads() !=null) b.transformerIntraOpThreads(onnxModelOptions.intraOpThreads()); - if (onnxModelOptions.gpuDevice() !=null) b.transformerGpuDevice(onnxModelOptions.gpuDevice().deviceNumber()); + b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerPath(onnxModelOptions.vocabRef().get()); + onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens); + onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds); + onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask); + onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput); + onnxModelOptions.maxQueryTokens().ifPresent(b::maxQueryTokens); + onnxModelOptions.maxDocumentTokens().ifPresent(b::maxDocumentTokens); + onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken); + onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken); + onnxModelOptions.transformerMaskToken().ifPresent(b::transformerMaskToken); + onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); + onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); + onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); + onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index 5bf30174539..af6febeb1e5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -7,6 +7,8 @@ import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import java.util.Optional; + import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; @@ -25,21 +27,21 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); var modelRef = model.modelReference(); this.onnxModelOptions = new OnnxModelOptions( - modelRef, - Model.fromXml(state, xml, "tokenizer-model") + Optional.of(modelRef), + Optional.of(Model.fromXml(state, xml, "tokenizer-model") .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state)), - getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null), - getChildValue(xml, "transformer-input-ids").orElse(null), - getChildValue(xml, "transformer-attention-mask").orElse(null), - getChildValue(xml, "transformer-token-type-ids").orElse(null), - getChildValue(xml, "transformer-output").orElse(null), - getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null), - getChildValue(xml, "onnx-execution-mode").orElse(null), - getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null), - getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null), - getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null), - getChildValue(xml, "pooling-strategy").orElse(null)); + .orElseGet(() -> resolveDefaultVocab(model, state))), + getChildValue(xml, "max-tokens").map(Integer::parseInt), + getChildValue(xml, "transformer-input-ids"), + getChildValue(xml, "transformer-attention-mask"), + getChildValue(xml, "transformer-token-type-ids"), + getChildValue(xml, "transformer-output"), + getChildValue(xml, "normalize").map(Boolean::parseBoolean), + getChildValue(xml, "onnx-execution-mode"), + getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new), + getChildValue(xml, "pooling-strategy")); model.registerOnnxModelCost(cluster); } @@ -53,18 +55,18 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm @Override public void getConfig(HuggingFaceEmbedderConfig.Builder b) { - b.transformerModel(onnxModelOptions.modelRef()).tokenizerPath(onnxModelOptions.vocabRef()); - if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens()); - if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds()); - if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask()); - if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds()); - if (onnxModelOptions.transformerOutput() != null) b.transformerOutput(onnxModelOptions.transformerOutput()); - if (onnxModelOptions.normalize() != null) b.normalize(onnxModelOptions.normalize()); - if (onnxModelOptions.executionMode() != null) b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(onnxModelOptions.executionMode())); - if (onnxModelOptions.interOpThreads() != null) b.transformerInterOpThreads(onnxModelOptions.interOpThreads()); - if (onnxModelOptions.intraOpThreads() != null) b.transformerIntraOpThreads(onnxModelOptions.intraOpThreads()); - if (onnxModelOptions.gpuDevice() != null) b.transformerGpuDevice(onnxModelOptions.gpuDevice().deviceNumber()); - if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy())); + b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerPath(onnxModelOptions.vocabRef().get()); + onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens); + onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds); + onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask); + onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds); + onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput); + onnxModelOptions.normalize().ifPresent(b::normalize); + onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); + onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); + onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); + onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); + onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value))); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java index 6b573e870fa..0efb8a71fe4 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java @@ -3,44 +3,56 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import java.util.Optional; + /** * @author hmusum */ -public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens, - String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds, - String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads, - Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy, - Integer transformerStartSequenceToken, Integer transformerEndSequenceToken, - Integer maxQueryTokens, Integer maxDocumentTokens, Integer transformerMaskToken) { +public record OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef, + Optional<Integer> maxTokens, Optional<String> transformerInputIds, + Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds, + Optional<String> transformerOutput, Optional<Boolean> normalize, + Optional<String> executionMode, Optional<Integer> interOpThreads, + Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice, + Optional<String> poolingStrategy, Optional<Integer> transformerStartSequenceToken, + Optional<Integer> transformerEndSequenceToken, Optional<Integer> maxQueryTokens, + Optional<Integer> maxDocumentTokens, Optional<Integer> transformerMaskToken) { - public OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens, - String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds, - String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads, - Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy) { + public OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef, + Optional<Integer> maxTokens, Optional<String> transformerInputIds, + Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds, + Optional<String> transformerOutput, Optional<Boolean> normalize, + Optional<String> executionMode, Optional<Integer> interOpThreads, + Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice, + Optional<String> poolingStrategy) { this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds, transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, - null, null, null, null, null); + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); } - public OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens, - String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds, - String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads, - Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy, - Integer transformerStartSequenceToken, Integer transformerEndSequenceToken) { + public OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef, + Optional<Integer> maxTokens, Optional<String> transformerInputIds, + Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds, + Optional<String> transformerOutput, Optional<Boolean> normalize, + Optional<String> executionMode, Optional<Integer> interOpThreads, + Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice, + Optional<String> poolingStrategy, Optional<Integer> transformerStartSequenceToken, + Optional<Integer> transformerEndSequenceToken) { this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds, transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, - transformerStartSequenceToken, transformerEndSequenceToken, null, null, null); + transformerStartSequenceToken, transformerEndSequenceToken, Optional.empty(), Optional.empty(), Optional.empty()); } public static OnnxModelOptions empty() { - return new OnnxModelOptions(null, null, null, null, null, null, null, null, null, null, null, null, null, - null, null, null, null, null); + return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); } public OnnxModelOptions withExecutionMode(String executionMode) { return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, - transformerTokenTypeIds, transformerOutput, normalize, executionMode, + transformerTokenTypeIds, transformerOutput, normalize, Optional.ofNullable(executionMode), interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, transformerStartSequenceToken, transformerEndSequenceToken, maxQueryTokens, maxDocumentTokens, transformerMaskToken); @@ -49,7 +61,7 @@ public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, public OnnxModelOptions withInteropThreads(Integer interopThreads) { return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds, transformerOutput, normalize, executionMode, - interopThreads, intraOpThreads, gpuDevice, poolingStrategy, + Optional.ofNullable(interopThreads), intraOpThreads, gpuDevice, poolingStrategy, transformerStartSequenceToken, transformerEndSequenceToken, maxQueryTokens, maxDocumentTokens, transformerMaskToken); } @@ -57,7 +69,7 @@ public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, public OnnxModelOptions withIntraopThreads(Integer intraopThreads) { return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds, transformerOutput, normalize, executionMode, - interOpThreads, intraopThreads, gpuDevice, poolingStrategy, + interOpThreads, Optional.ofNullable(intraopThreads), gpuDevice, poolingStrategy, transformerStartSequenceToken, transformerEndSequenceToken, maxQueryTokens, maxDocumentTokens, transformerMaskToken); } @@ -66,7 +78,7 @@ public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) { return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds, transformerOutput, normalize, executionMode, - interOpThreads, intraOpThreads, gpuDevice, poolingStrategy, + interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice), poolingStrategy, transformerStartSequenceToken, transformerEndSequenceToken, maxQueryTokens, maxDocumentTokens, transformerMaskToken); } |