summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorHarald Musum <musum@yahooinc.com>2023-11-16 11:19:28 +0100
committerHarald Musum <musum@yahooinc.com>2023-11-16 11:19:28 +0100
commit2d1ba9b48bdad4c1af7058a37821e447c89bb420 (patch)
treee2d912a1a0a174834196e967b420440f2d7f32c7 /config-model
parentd20e4909d59116b231c4b25fba35f22e673e407b (diff)
Include only options that control how an Onnx model is loaded in OnnxModelOptions
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java54
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java68
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java52
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java77
5 files changed, 114 insertions, 141 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 0d2c695280e..867ffdb3960 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -143,7 +143,7 @@ public class OnnxModel extends DistributableResource implements Cloneable {
public void setStatelessInterOpThreads(int interOpThreads) {
if (interOpThreads >= 0) {
- onnxModelOptions = onnxModelOptions.withInteropThreads(interOpThreads);
+ onnxModelOptions = onnxModelOptions.withInterOpThreads(interOpThreads);
}
}
@@ -153,7 +153,7 @@ public class OnnxModel extends DistributableResource implements Cloneable {
public void setStatelessIntraOpThreads(int intraOpThreads) {
if (intraOpThreads >= 0) {
- onnxModelOptions = onnxModelOptions.withIntraopThreads(intraOpThreads);
+ onnxModelOptions = onnxModelOptions.withIntraOpThreads(intraOpThreads);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
index 78177fd7d57..ea3caadc23a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -1,13 +1,12 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.component;
+import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
-import java.util.Optional;
-
import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode;
import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy;
import static com.yahoo.text.XML.getChildValue;
@@ -19,40 +18,49 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer {
private final OnnxModelOptions onnxModelOptions;
+ private final ModelReference modelRef;
+ private final ModelReference vocabRef;
+ private final Integer maxTokens;
+ private final String transformerInputIds;
+ private final String transformerAttentionMask;
+ private final String transformerTokenTypeIds;
+ private final String transformerOutput;
+ private final Integer transformerStartSequenceToken;
+ private final Integer transformerEndSequenceToken;
+ private final String poolingStrategy;
public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
- Optional.of(model.modelReference()),
- Optional.of(Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference()),
- getChildValue(xml, "max-tokens").map(Integer::parseInt),
- getChildValue(xml, "transformer-input-ids"),
- getChildValue(xml, "transformer-attention-mask"),
- getChildValue(xml, "transformer-token-type-ids"),
- getChildValue(xml, "transformer-output"),
- getChildValue(xml, "normalize").map(Boolean::parseBoolean),
getChildValue(xml, "onnx-execution-mode"),
getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
- getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new),
- getChildValue(xml, "pooling-strategy"),
- getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt),
- getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt));
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
+ modelRef = model.modelReference();
+ vocabRef = Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference();
+ maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
+ transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
+ transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
+ transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null);
+ transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
+ transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
+ transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
+ poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
model.registerOnnxModelCost(cluster);
}
@Override
public void getConfig(BertBaseEmbedderConfig.Builder b) {
- b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerVocab(onnxModelOptions.vocabRef().get());
- onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens);
- onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds);
- onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask);
- onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds);
- onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput);
- onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken);
- onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken);
- onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value)));
+ b.transformerModel(modelRef).tokenizerVocab(vocabRef);
+ if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+ if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+ if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+ if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
+ if (transformerOutput != null) b.transformerOutput(transformerOutput);
+ if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken);
+ if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
+ if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy));
onnxModelOptions.executionMode().ifPresent(value -> b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(value)));
onnxModelOptions.interOpThreads().ifPresent(b::onnxInterOpThreads);
onnxModelOptions.intraOpThreads().ifPresent(b::onnxIntraOpThreads);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index 33753094f69..cbae50b400c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -7,8 +7,6 @@ import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
-import java.util.Optional;
-
import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
@@ -20,31 +18,43 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer {
private final OnnxModelOptions onnxModelOptions;
+ private final ModelReference modelRef;
+ private final ModelReference vocabRef;
+
+ private final Integer maxQueryTokens;
+
+ private final Integer maxDocumentTokens;
+
+ private final Integer transformerStartSequenceToken;
+ private final Integer transformerEndSequenceToken;
+ private final Integer transformerMaskToken;
+ private final Integer maxTokens;
+ private final String transformerInputIds;
+ private final String transformerAttentionMask;
+
+ private final String transformerOutput;
public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
- Optional.of(model.modelReference()),
- Optional.of(Model.fromXml(state, xml, "tokenizer-model")
- .map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state))),
- getChildValue(xml, "max-tokens").map(Integer::parseInt),
- getChildValue(xml, "transformer-input-ids"),
- getChildValue(xml, "transformer-attention-mask"),
- getChildValue(xml, "transformer-token-type-ids"),
- getChildValue(xml, "transformer-output"),
- getChildValue(xml, "normalize").map(Boolean::parseBoolean),
getChildValue(xml, "onnx-execution-mode"),
getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
- getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new),
- getChildValue(xml, "pooling-strategy"),
- getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt),
- getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt),
- getChildValue(xml, "max-query-tokens").map(Integer::parseInt),
- getChildValue(xml, "max-document-tokens").map(Integer::parseInt),
- getChildValue(xml, "transformer-mask-token").map(Integer::parseInt));
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
+ modelRef = model.modelReference();
+ vocabRef = Model.fromXml(state, xml, "tokenizer-model")
+ .map(Model::modelReference)
+ .orElseGet(() -> resolveDefaultVocab(model, state));
+ maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
+ maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null);
+ maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null);
+ transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
+ transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
+ transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null);
+ transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
+ transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
+ transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
model.registerOnnxModelCost(cluster);
}
@@ -58,16 +68,16 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
@Override
public void getConfig(ColBertEmbedderConfig.Builder b) {
- b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerPath(onnxModelOptions.vocabRef().get());
- onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens);
- onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds);
- onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask);
- onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput);
- onnxModelOptions.maxQueryTokens().ifPresent(b::maxQueryTokens);
- onnxModelOptions.maxDocumentTokens().ifPresent(b::maxDocumentTokens);
- onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken);
- onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken);
- onnxModelOptions.transformerMaskToken().ifPresent(b::transformerMaskToken);
+ b.transformerModel(modelRef).tokenizerPath(vocabRef);
+ if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+ if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+ if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+ if (transformerOutput != null) b.transformerOutput(transformerOutput);
+ if (maxQueryTokens != null) b.maxQueryTokens(maxQueryTokens);
+ if (maxDocumentTokens != null) b.maxDocumentTokens(maxDocumentTokens);
+ if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken);
+ if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
+ if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken);
onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index af6febeb1e5..d1bd0dce000 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -7,8 +7,6 @@ import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
-import java.util.Optional;
-
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode;
import static com.yahoo.text.XML.getChildValue;
@@ -21,27 +19,35 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer {
private final OnnxModelOptions onnxModelOptions;
+ private final ModelReference modelRef;
+ private final ModelReference vocabRef;
+ private final Integer maxTokens;
+ private final String transformerInputIds;
+ private final String transformerAttentionMask;
+ private final String transformerTokenTypeIds;
+ private final String transformerOutput;
+ private final Boolean normalize;
+ private final String poolingStrategy;
public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
- var modelRef = model.modelReference();
this.onnxModelOptions = new OnnxModelOptions(
- Optional.of(modelRef),
- Optional.of(Model.fromXml(state, xml, "tokenizer-model")
- .map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state))),
- getChildValue(xml, "max-tokens").map(Integer::parseInt),
- getChildValue(xml, "transformer-input-ids"),
- getChildValue(xml, "transformer-attention-mask"),
- getChildValue(xml, "transformer-token-type-ids"),
- getChildValue(xml, "transformer-output"),
- getChildValue(xml, "normalize").map(Boolean::parseBoolean),
getChildValue(xml, "onnx-execution-mode"),
getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
- getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new),
- getChildValue(xml, "pooling-strategy"));
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
+ modelRef = model.modelReference();
+ vocabRef = Model.fromXml(state, xml, "tokenizer-model")
+ .map(Model::modelReference)
+ .orElseGet(() -> resolveDefaultVocab(model, state));
+ maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
+ transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
+ transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
+ transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null);
+ transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
+ normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
+ poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
model.registerOnnxModelCost(cluster);
}
@@ -55,18 +61,18 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
@Override
public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
- b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerPath(onnxModelOptions.vocabRef().get());
- onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens);
- onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds);
- onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask);
- onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds);
- onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput);
- onnxModelOptions.normalize().ifPresent(b::normalize);
+ b.transformerModel(modelRef).tokenizerPath(vocabRef);
+ if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+ if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+ if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+ if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
+ if (transformerOutput != null) b.transformerOutput(transformerOutput);
+ if (normalize != null) b.normalize(normalize);
+ if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy));
onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
- onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value)));
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java
index 0efb8a71fe4..6347f0dc427 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java
@@ -1,86 +1,35 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.component;
-import com.yahoo.config.ModelReference;
-
import java.util.Optional;
/**
+ * Onnx model options that are relevant when deciding if an Onnx model needs to be reloaded. If any of the
+ * values in this class change, reload is needed.
+ *
* @author hmusum
*/
-public record OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef,
- Optional<Integer> maxTokens, Optional<String> transformerInputIds,
- Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds,
- Optional<String> transformerOutput, Optional<Boolean> normalize,
- Optional<String> executionMode, Optional<Integer> interOpThreads,
- Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice,
- Optional<String> poolingStrategy, Optional<Integer> transformerStartSequenceToken,
- Optional<Integer> transformerEndSequenceToken, Optional<Integer> maxQueryTokens,
- Optional<Integer> maxDocumentTokens, Optional<Integer> transformerMaskToken) {
-
-
- public OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef,
- Optional<Integer> maxTokens, Optional<String> transformerInputIds,
- Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds,
- Optional<String> transformerOutput, Optional<Boolean> normalize,
- Optional<String> executionMode, Optional<Integer> interOpThreads,
- Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice,
- Optional<String> poolingStrategy) {
- this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds,
- transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
- Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
- }
+public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads,
+ Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) {
- public OnnxModelOptions(Optional<ModelReference> modelRef, Optional<ModelReference> vocabRef,
- Optional<Integer> maxTokens, Optional<String> transformerInputIds,
- Optional<String> transformerAttentionMask, Optional<String> transformerTokenTypeIds,
- Optional<String> transformerOutput, Optional<Boolean> normalize,
- Optional<String> executionMode, Optional<Integer> interOpThreads,
- Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice,
- Optional<String> poolingStrategy, Optional<Integer> transformerStartSequenceToken,
- Optional<Integer> transformerEndSequenceToken) {
- this(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask, transformerTokenTypeIds,
- transformerOutput, normalize, executionMode, interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
- transformerStartSequenceToken, transformerEndSequenceToken, Optional.empty(), Optional.empty(), Optional.empty());
- }
-
- public static OnnxModelOptions empty() {
- return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(),
- Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(),
- Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
+ public static OnnxModelOptions empty() {
+ return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
}
public OnnxModelOptions withExecutionMode(String executionMode) {
- return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
- transformerTokenTypeIds, transformerOutput, normalize, Optional.ofNullable(executionMode),
- interOpThreads, intraOpThreads, gpuDevice, poolingStrategy,
- transformerStartSequenceToken, transformerEndSequenceToken,
- maxQueryTokens, maxDocumentTokens, transformerMaskToken);
+ return new OnnxModelOptions(Optional.ofNullable(executionMode), interOpThreads, intraOpThreads, gpuDevice);
}
- public OnnxModelOptions withInteropThreads(Integer interopThreads) {
- return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
- transformerTokenTypeIds, transformerOutput, normalize, executionMode,
- Optional.ofNullable(interopThreads), intraOpThreads, gpuDevice, poolingStrategy,
- transformerStartSequenceToken, transformerEndSequenceToken,
- maxQueryTokens, maxDocumentTokens, transformerMaskToken);
+ public OnnxModelOptions withInterOpThreads(Integer interOpThreads) {
+ return new OnnxModelOptions(executionMode, Optional.ofNullable(interOpThreads), intraOpThreads, gpuDevice);
}
- public OnnxModelOptions withIntraopThreads(Integer intraopThreads) {
- return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
- transformerTokenTypeIds, transformerOutput, normalize, executionMode,
- interOpThreads, Optional.ofNullable(intraopThreads), gpuDevice, poolingStrategy,
- transformerStartSequenceToken, transformerEndSequenceToken,
- maxQueryTokens, maxDocumentTokens, transformerMaskToken);
+ public OnnxModelOptions withIntraOpThreads(Integer intraOpThreads) {
+ return new OnnxModelOptions(executionMode, interOpThreads, Optional.ofNullable(intraOpThreads), gpuDevice);
}
-
public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) {
- return new OnnxModelOptions(modelRef, vocabRef, maxTokens, transformerInputIds, transformerAttentionMask,
- transformerTokenTypeIds, transformerOutput, normalize, executionMode,
- interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice), poolingStrategy,
- transformerStartSequenceToken, transformerEndSequenceToken,
- maxQueryTokens, maxDocumentTokens, transformerMaskToken);
+ return new OnnxModelOptions(executionMode, interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice));
}
public record GpuDevice(int deviceNumber, boolean required) {