path: root/config-model
author     Harald Musum <musum@yahooinc.com>  2023-11-16 10:45:20 +0100
committer  Harald Musum <musum@yahooinc.com>  2023-11-16 10:45:20 +0100
commit     d20e4909d59116b231c4b25fba35f22e673e407b (patch)
tree       30447ea12070c1c4fe7f209f0f1fb8fb7208cb18 /config-model
parent     ae5b333a133add3fc6d7ab6c056379c5b4f1564d (diff)
Use Optionals in OnnxModelOptions
Diffstat (limited to 'config-model')
-rw-r--r--  config-model/src/main/java/com/yahoo/schema/OnnxModel.java                                    |  8
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java        | 58
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java     | 68
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 54
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java    | 58
5 files changed, 132 insertions(+), 114 deletions(-)
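
A minimal standalone sketch of the pattern this commit adopts: record components typed as Optional, accessors returned directly, and builder-style population via ifPresent instead of null checks. The class, record, and field names below are invented for illustration and are not part of the Vespa source.

import java.util.Optional;

// Hypothetical sketch, not Vespa code: a record with Optional components,
// an empty() factory, a with-style copier, and ifPresent-based consumption.
public class OptionalOptionsSketch {

    record Options(Optional<String> executionMode, Optional<Integer> interOpThreads) {

        static Options empty() {
            return new Options(Optional.empty(), Optional.empty());
        }

        // Optional.ofNullable at the boundary keeps "null means unset" for callers
        Options withExecutionMode(String executionMode) {
            return new Options(Optional.ofNullable(executionMode), interOpThreads);
        }
    }

    public static void main(String[] args) {
        Options options = Options.empty().withExecutionMode("parallel");

        // Before: if (options.executionMode() != null) { ... }
        options.executionMode().ifPresent(mode -> System.out.println("execution mode: " + mode));
        options.interOpThreads().ifPresent(n -> System.out.println("inter-op threads: " + n));

        // Accessors already return Optional, so getters can delegate directly
        // instead of wrapping the result in Optional.ofNullable.
        Optional<String> mode = options.executionMode();
        System.out.println("mode present: " + mode.isPresent());
    }
}
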
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index b22e737fd6a..0d2c695280e 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -138,7 +138,7 @@ public class OnnxModel extends DistributableResource implements Cloneable {
}
public Optional<String> getStatelessExecutionMode() {
- return Optional.ofNullable(onnxModelOptions.executionMode());
+ return onnxModelOptions.executionMode();
}
public void setStatelessInterOpThreads(int interOpThreads) {
@@ -148,7 +148,7 @@ public class OnnxModel extends DistributableResource implements Cloneable {
}
public Optional<Integer> getStatelessInterOpThreads() {
- return Optional.ofNullable(onnxModelOptions.interOpThreads());
+ return onnxModelOptions.interOpThreads();
}
public void setStatelessIntraOpThreads(int intraOpThreads) {
@@ -158,7 +158,7 @@ public class OnnxModel extends DistributableResource implements Cloneable {
}
public Optional<Integer> getStatelessIntraOpThreads() {
- return Optional.ofNullable(onnxModelOptions.intraOpThreads());
+ return onnxModelOptions.intraOpThreads();
}
public void setGpuDevice(int deviceNumber, boolean required) {
@@ -168,7 +168,7 @@ public class OnnxModel extends DistributableResource implements Cloneable {
}
public Optional<OnnxModelOptions.GpuDevice> getGpuDevice() {
- return Optional.ofNullable(onnxModelOptions.gpuDevice());
+ return onnxModelOptions.gpuDevice();
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
index 67edb40b4d3..78177fd7d57 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -6,6 +6,8 @@ import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import java.util.Optional;
+
import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode;
import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy;
import static com.yahoo.text.XML.getChildValue;
@@ -22,39 +24,39 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf
super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
- model.modelReference(),
- Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(),
- getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "transformer-input-ids").orElse(null),
- getChildValue(xml, "transformer-attention-mask").orElse(null),
- getChildValue(xml, "transformer-token-type-ids").orElse(null),
- getChildValue(xml, "transformer-output").orElse(null),
- getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null),
- getChildValue(xml, "onnx-execution-mode").orElse(null),
- getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null),
- getChildValue(xml, "pooling-strategy").orElse(null),
- getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null));
+ Optional.of(model.modelReference()),
+ Optional.of(Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference()),
+ getChildValue(xml, "max-tokens").map(Integer::parseInt),
+ getChildValue(xml, "transformer-input-ids"),
+ getChildValue(xml, "transformer-attention-mask"),
+ getChildValue(xml, "transformer-token-type-ids"),
+ getChildValue(xml, "transformer-output"),
+ getChildValue(xml, "normalize").map(Boolean::parseBoolean),
+ getChildValue(xml, "onnx-execution-mode"),
+ getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new),
+ getChildValue(xml, "pooling-strategy"),
+ getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt),
+ getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt));
model.registerOnnxModelCost(cluster);
}
@Override
public void getConfig(BertBaseEmbedderConfig.Builder b) {
- b.transformerModel(onnxModelOptions.modelRef()).tokenizerVocab(onnxModelOptions.vocabRef());
- if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens());
- if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds());
- if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask());
- if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds());
- if (onnxModelOptions.transformerOutput() != null) b.transformerOutput(onnxModelOptions.transformerOutput());
- if (onnxModelOptions.transformerStartSequenceToken() != null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken());
- if (onnxModelOptions.transformerEndSequenceToken() != null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken());
- if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy()));
- if (onnxModelOptions.executionMode() != null) b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(onnxModelOptions.executionMode()));
- if (onnxModelOptions.interOpThreads() != null) b.onnxInterOpThreads(onnxModelOptions.interOpThreads());
- if (onnxModelOptions.intraOpThreads() != null) b.onnxIntraOpThreads(onnxModelOptions.intraOpThreads());
- if (onnxModelOptions.gpuDevice() != null) b.onnxGpuDevice(onnxModelOptions.gpuDevice().deviceNumber());
+ b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerVocab(onnxModelOptions.vocabRef().get());
+ onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens);
+ onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds);
+ onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask);
+ onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds);
+ onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput);
+ onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken);
+ onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken);
+ onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value)));
+ onnxModelOptions.executionMode().ifPresent(value -> b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(value)));
+ onnxModelOptions.interOpThreads().ifPresent(b::onnxInterOpThreads);
+ onnxModelOptions.intraOpThreads().ifPresent(b::onnxIntraOpThreads);
+ onnxModelOptions.gpuDevice().ifPresent(value -> b.onnxGpuDevice(value.deviceNumber()));
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index 5192c806797..33753094f69 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -7,6 +7,8 @@ import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import java.util.Optional;
+
import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
@@ -23,26 +25,26 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
- model.modelReference(),
- Model.fromXml(state, xml, "tokenizer-model")
+ Optional.of(model.modelReference()),
+ Optional.of(Model.fromXml(state, xml, "tokenizer-model")
.map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state)),
- getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "transformer-input-ids").orElse(null),
- getChildValue(xml, "transformer-attention-mask").orElse(null),
- getChildValue(xml, "transformer-token-type-ids").orElse(null),
- getChildValue(xml, "transformer-output").orElse(null),
- getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null),
- getChildValue(xml, "onnx-execution-mode").orElse(null),
- getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null),
- getChildValue(xml, "pooling-strategy").orElse(null),
- getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null));
+ .orElseGet(() -> resolveDefaultVocab(model, state))),
+ getChildValue(xml, "max-tokens").map(Integer::parseInt),
+ getChildValue(xml, "transformer-input-ids"),
+ getChildValue(xml, "transformer-attention-mask"),
+ getChildValue(xml, "transformer-token-type-ids"),
+ getChildValue(xml, "transformer-output"),
+ getChildValue(xml, "normalize").map(Boolean::parseBoolean),
+ getChildValue(xml, "onnx-execution-mode"),
+ getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new),
+ getChildValue(xml, "pooling-strategy"),
+ getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt),
+ getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt),
+ getChildValue(xml, "max-query-tokens").map(Integer::parseInt),
+ getChildValue(xml, "max-document-tokens").map(Integer::parseInt),
+ getChildValue(xml, "transformer-mask-token").map(Integer::parseInt));
model.registerOnnxModelCost(cluster);
}
@@ -56,20 +58,20 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
@Override
public void getConfig(ColBertEmbedderConfig.Builder b) {
- b.transformerModel(onnxModelOptions.modelRef()).tokenizerPath(onnxModelOptions.vocabRef());
- if (onnxModelOptions.maxTokens() !=null) b.transformerMaxTokens(onnxModelOptions.maxTokens());
- if (onnxModelOptions.transformerInputIds() !=null) b.transformerInputIds(onnxModelOptions.transformerInputIds());
- if (onnxModelOptions.transformerAttentionMask() !=null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask());
- if (onnxModelOptions.transformerOutput() !=null) b.transformerOutput(onnxModelOptions.transformerOutput());
- if (onnxModelOptions.maxQueryTokens() !=null) b.maxQueryTokens(onnxModelOptions.maxQueryTokens());
- if (onnxModelOptions.maxDocumentTokens() !=null) b.maxDocumentTokens(onnxModelOptions.maxDocumentTokens());
- if (onnxModelOptions.transformerStartSequenceToken() !=null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken());
- if (onnxModelOptions.transformerEndSequenceToken() !=null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken());
- if (onnxModelOptions.transformerMaskToken() !=null) b.transformerMaskToken(onnxModelOptions.transformerMaskToken());
- if (onnxModelOptions.executionMode() !=null) b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(onnxModelOptions.executionMode()));
- if (onnxModelOptions.interOpThreads() !=null) b.transformerInterOpThreads(onnxModelOptions.interOpThreads());
- if (onnxModelOptions.intraOpThreads() !=null) b.transformerIntraOpThreads(onnxModelOptions.intraOpThreads());
- if (onnxModelOptions.gpuDevice() !=null) b.transformerGpuDevice(onnxModelOptions.gpuDevice().deviceNumber());
+ b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerPath(onnxModelOptions.vocabRef().get());
+ onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens);
+ onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds);
+ onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask);
+ onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput);
+ onnxModelOptions.maxQueryTokens().ifPresent(b::maxQueryTokens);
+ onnxModelOptions.maxDocumentTokens().ifPresent(b::maxDocumentTokens);
+ onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken);
+ onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken);
+ onnxModelOptions.transformerMaskToken().ifPresent(b::transformerMaskToken);
+ onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
+ onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
+ onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
+ onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index 5bf30174539..af6febeb1e5 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -7,6 +7,8 @@ import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import java.util.Optional;
+
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode;
import static com.yahoo.text.XML.getChildValue;
@@ -25,21 +27,21 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
var modelRef = model.modelReference();
this.onnxModelOptions = new OnnxModelOptions(
- modelRef,
- Model.fromXml(state, xml, "tokenizer-model")
+ Optional.of(modelRef),
+ Optional.of(Model.fromXml(state, xml, "tokenizer-model")
.map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state)),
- getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "transformer-input-ids").orElse(null),
- getChildValue(xml, "transformer-attention-mask").orElse(null),
- getChildValue(xml, "transformer-token-type-ids").orElse(null),
- getChildValue(xml, "transformer-output").orElse(null),
- getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null),
- getChildValue(xml, "onnx-execution-mode").orElse(null),
- getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null),
- getChildValue(xml, "pooling-strategy").orElse(null));
+ .orElseGet(() -> resolveDefaultVocab(model, state))),
+ getChildValue(xml, "max-tokens").map(Integer::parseInt),
+ getChildValue(xml, "transformer-input-ids"),
+ getChildValue(xml, "transformer-attention-mask"),
+ getChildValue(xml, "transformer-token-type-ids"),
+ getChildValue(xml, "transformer-output"),
+ getChildValue(xml, "normalize").map(Boolean::parseBoolean),
+ getChildValue(xml, "onnx-execution-mode"),
+ getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new),
+ getChildValue(xml, "pooling-strategy"));
model.registerOnnxModelCost(cluster);
}
@@ -53,18 +55,18 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
@Override
public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
- b.transformerModel(onnxModelOptions.modelRef()).tokenizerPath(onnxModelOptions.vocabRef());
- if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens());
- if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds());
- if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask());
- if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds());
- if (onnxModelOptions.transformerOutput() != null) b.transformerOutput(onnxModelOptions.transformerOutput());
- if (onnxModelOptions.normalize() != null) b.normalize(onnxModelOptions.normalize());
- if (onnxModelOptions.executionMode() != null) b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(onnxModelOptions.executionMode()));
- if (onnxModelOptions.interOpThreads() != null) b.transformerInterOpThreads(onnxModelOptions.interOpThreads());
- if (onnxModelOptions.intraOpThreads() != null) b.transformerIntraOpThreads(onnxModelOptions.intraOpThreads());
- if (onnxModelOptions.gpuDevice() != null) b.transformerGpuDevice(onnxModelOptions.gpuDevice().deviceNumber());
- if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy()));
+ b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerPath(onnxModelOptions.vocabRef().get());
+ onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens);
+ onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds);
+ onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask);
+ onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds);
+ onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput);
+ onnxModelOptions.normalize().ifPresent(b::normalize);
+ onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
+ onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
+ onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
+ onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
+ onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value)));
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java
index 6b573e870fa..0efb8a71fe4 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java
@@ -3,44 +3,56 @@ package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
+import java.util.Optional;
+
/**
* @author hmusum
*/
-public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens,
- String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds,
- String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads,
- Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy,
- Integer transformerStartSequenceToken, Integer transformerEndSequenceToken,
- Integer maxQueryTokens, Integer maxDocumentTokens, Integer transformerMaskToken) {
+public record OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef,
+ Optional<Integer> maxTokens, Optional<String> transformerInputIds,
+ Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds,
+ Optional<String> transformerOutput, Optional<Boolean> normalize,
+ Optional<String> executionMode, Optional<Integer> interOpThreads,
+ Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice,
+ Optional<String> poolingStrategy, Optional<Integer> transformerStartSequenceToken,
+ Optional<Integer> transformerEndSequenceToken, Optional<Integer> maxQueryTokens,
+ Optional<Integer> maxDocumentTokens, Optional<Integer> transformerMaskToken) {
- public OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens,
- String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds,
- String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads,
- Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy) {
+ public OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef,
+ Optional<Integer> maxTokens, Optional<String> transformerInputIds,
+ Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds,
+ Optional<String> transformerOutput, Optional<Boolean> normalize,
+ Optional<String> executionMode, Optional<Integer> interOpThreads,
+ Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice,
+ Optional<String> poolingStrategy) {
this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds,
transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
- null, null, null, null, null);
+ Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
}
- public OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef, Integer maxTokens,
- String transformerInputIds, String transformerAttentionMask, String transformerTokenTypeIds,
- String transformerOutput, Boolean normalize, String executionMode, Integer interOpThreads,
- Integer intraOpThreads, GpuDevice gpuDevice, String poolingStrategy,
- Integer transformerStartSequenceToken, Integer transformerEndSequenceToken) {
+ public OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef,
+ Optional<Integer> maxTokens, Optional<String> transformerInputIds,
+ Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds,
+ Optional<String> transformerOutput, Optional<Boolean> normalize,
+ Optional<String> executionMode, Optional<Integer> interOpThreads,
+ Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice,
+ Optional<String> poolingStrategy, Optional<Integer> transformerStartSequenceToken,
+ Optional<Integer> transformerEndSequenceToken) {
this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds,
transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
- transformerStartSequenceToken, transformerEndSequenceToken, null, null, null);
+ transformerStartSequenceToken, transformerEndSequenceToken, Optional.empty(), Optional.empty(), Optional.empty());
}
public static OnnxModelOptions empty() {
- return new OnnxModelOptions(null, null, null, null, null, null, null, null, null, null, null, null, null,
- null, null, null, null, null);
+ return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(),
+ Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(),
+ Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
}
public OnnxModelOptions withExecutionMode(String executionMode) {
return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
- transformerTokenTypeIds, transformerOutput, normalize, executionMode,
+ transformerTokenTypeIds, transformerOutput, normalize, Optional.ofNullable(executionMode),
interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
transformerStartSequenceToken, transformerEndSequenceToken,
maxQueryTokens, maxDocumentTokens, transformerMaskToken);
@@ -49,7 +61,7 @@ public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef,
public OnnxModelOptions withInteropThreads(Integer interopThreads) {
return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
transformerTokenTypeIds, transformerOutput, normalize, executionMode,
- interopThreads, intraOpThreads, gpuDevice, poolingStrategy,
+ Optional.ofNullable(interopThreads), intraOpThreads, gpuDevice, poolingStrategy,
transformerStartSequenceToken, transformerEndSequenceToken,
maxQueryTokens, maxDocumentTokens, transformerMaskToken);
}
@@ -57,7 +69,7 @@ public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef,
public OnnxModelOptions withIntraopThreads(Integer intraopThreads) {
return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
transformerTokenTypeIds, transformerOutput, normalize, executionMode,
- interOpThreads, intraopThreads, gpuDevice, poolingStrategy,
+ interOpThreads, Optional.ofNullable(intraopThreads), gpuDevice, poolingStrategy,
transformerStartSequenceToken, transformerEndSequenceToken,
maxQueryTokens, maxDocumentTokens, transformerMaskToken);
}
@@ -66,7 +78,7 @@ public record OnnxModelOptions(ModelReference modelRef, ModelReference vocabRef,
public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) {
return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
transformerTokenTypeIds, transformerOutput, normalize, executionMode,
- interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
+ interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice), poolingStrategy,
transformerStartSequenceToken, transformerEndSequenceToken,
maxQueryTokens, maxDocumentTokens, transformerMaskToken);
}