summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java58
1 files changed, 30 insertions, 28 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
index 67edb40b4d3..78177fd7d57 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -6,6 +6,8 @@ import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import java.util.Optional;
+
import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode;
import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy;
import static com.yahoo.text.XML.getChildValue;
@@ -22,39 +24,39 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf
super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
- model.modelReference(),
- Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(),
- getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "transformer-input-ids").orElse(null),
- getChildValue(xml, "transformer-attention-mask").orElse(null),
- getChildValue(xml, "transformer-token-type-ids").orElse(null),
- getChildValue(xml, "transformer-output").orElse(null),
- getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null),
- getChildValue(xml, "onnx-execution-mode").orElse(null),
- getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null),
- getChildValue(xml, "pooling-strategy").orElse(null),
- getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null),
- getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null));
+ Optional.of(model.modelReference()),
+ Optional.of(Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference()),
+ getChildValue(xml, "max-tokens").map(Integer::parseInt),
+ getChildValue(xml, "transformer-input-ids"),
+ getChildValue(xml, "transformer-attention-mask"),
+ getChildValue(xml, "transformer-token-type-ids"),
+ getChildValue(xml, "transformer-output"),
+ getChildValue(xml, "normalize").map(Boolean::parseBoolean),
+ getChildValue(xml, "onnx-execution-mode"),
+ getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new),
+ getChildValue(xml, "pooling-strategy"),
+ getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt),
+ getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt));
model.registerOnnxModelCost(cluster);
}
@Override
public void getConfig(BertBaseEmbedderConfig.Builder b) {
- b.transformerModel(onnxModelOptions.modelRef()).tokenizerVocab(onnxModelOptions.vocabRef());
- if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens());
- if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds());
- if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask());
- if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds());
- if (onnxModelOptions.transformerOutput() != null) b.transformerOutput(onnxModelOptions.transformerOutput());
- if (onnxModelOptions.transformerStartSequenceToken() != null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken());
- if (onnxModelOptions.transformerEndSequenceToken() != null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken());
- if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy()));
- if (onnxModelOptions.executionMode() != null) b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(onnxModelOptions.executionMode()));
- if (onnxModelOptions.interOpThreads() != null) b.onnxInterOpThreads(onnxModelOptions.interOpThreads());
- if (onnxModelOptions.intraOpThreads() != null) b.onnxIntraOpThreads(onnxModelOptions.intraOpThreads());
- if (onnxModelOptions.gpuDevice() != null) b.onnxGpuDevice(onnxModelOptions.gpuDevice().deviceNumber());
+ b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerVocab(onnxModelOptions.vocabRef().get());
+ onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens);
+ onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds);
+ onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask);
+ onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds);
+ onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput);
+ onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken);
+ onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken);
+ onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value)));
+ onnxModelOptions.executionMode().ifPresent(value -> b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(value)));
+ onnxModelOptions.interOpThreads().ifPresent(b::onnxInterOpThreads);
+ onnxModelOptions.intraOpThreads().ifPresent(b::onnxIntraOpThreads);
+ onnxModelOptions.gpuDevice().ifPresent(value -> b.onnxGpuDevice(value.deviceNumber()));
}
}