diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java | 58 |
1 files changed, 30 insertions, 28 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index 67edb40b4d3..78177fd7d57 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -6,6 +6,8 @@ import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import java.util.Optional; + import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode; import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy; import static com.yahoo.text.XML.getChildValue; @@ -22,39 +24,39 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( - model.modelReference(), - Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(), - getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null), - getChildValue(xml, "transformer-input-ids").orElse(null), - getChildValue(xml, "transformer-attention-mask").orElse(null), - getChildValue(xml, "transformer-token-type-ids").orElse(null), - getChildValue(xml, "transformer-output").orElse(null), - getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null), - getChildValue(xml, "onnx-execution-mode").orElse(null), - getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null), - getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null), - getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new).orElse(null), - getChildValue(xml, "pooling-strategy").orElse(null), - getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null), - getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null)); + Optional.of(model.modelReference()), + Optional.of(Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference()), + getChildValue(xml, "max-tokens").map(Integer::parseInt), + getChildValue(xml, "transformer-input-ids"), + getChildValue(xml, "transformer-attention-mask"), + getChildValue(xml, "transformer-token-type-ids"), + getChildValue(xml, "transformer-output"), + getChildValue(xml, "normalize").map(Boolean::parseBoolean), + getChildValue(xml, "onnx-execution-mode"), + getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new), + getChildValue(xml, "pooling-strategy"), + getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt), + getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt)); model.registerOnnxModelCost(cluster); } @Override public void getConfig(BertBaseEmbedderConfig.Builder b) { - b.transformerModel(onnxModelOptions.modelRef()).tokenizerVocab(onnxModelOptions.vocabRef()); - if (onnxModelOptions.maxTokens() != null) b.transformerMaxTokens(onnxModelOptions.maxTokens()); - if (onnxModelOptions.transformerInputIds() != null) b.transformerInputIds(onnxModelOptions.transformerInputIds()); - if (onnxModelOptions.transformerAttentionMask() != null) b.transformerAttentionMask(onnxModelOptions.transformerAttentionMask()); - if (onnxModelOptions.transformerTokenTypeIds() != null) b.transformerTokenTypeIds(onnxModelOptions.transformerTokenTypeIds()); - if (onnxModelOptions.transformerOutput() != null) b.transformerOutput(onnxModelOptions.transformerOutput()); - if (onnxModelOptions.transformerStartSequenceToken() != null) b.transformerStartSequenceToken(onnxModelOptions.transformerStartSequenceToken()); - if (onnxModelOptions.transformerEndSequenceToken() != null) b.transformerEndSequenceToken(onnxModelOptions.transformerEndSequenceToken()); - if (onnxModelOptions.poolingStrategy() != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(onnxModelOptions.poolingStrategy())); - if (onnxModelOptions.executionMode() != null) b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(onnxModelOptions.executionMode())); - if (onnxModelOptions.interOpThreads() != null) b.onnxInterOpThreads(onnxModelOptions.interOpThreads()); - if (onnxModelOptions.intraOpThreads() != null) b.onnxIntraOpThreads(onnxModelOptions.intraOpThreads()); - if (onnxModelOptions.gpuDevice() != null) b.onnxGpuDevice(onnxModelOptions.gpuDevice().deviceNumber()); + b.transformerModel(onnxModelOptions.modelRef().get()).tokenizerVocab(onnxModelOptions.vocabRef().get()); + onnxModelOptions.maxTokens().ifPresent(b::transformerMaxTokens); + onnxModelOptions.transformerInputIds().ifPresent(b::transformerInputIds); + onnxModelOptions.transformerAttentionMask().ifPresent(b::transformerAttentionMask); + onnxModelOptions.transformerTokenTypeIds().ifPresent(b::transformerTokenTypeIds); + onnxModelOptions.transformerOutput().ifPresent(b::transformerOutput); + onnxModelOptions.transformerStartSequenceToken().ifPresent(b::transformerStartSequenceToken); + onnxModelOptions.transformerEndSequenceToken().ifPresent(b::transformerEndSequenceToken); + onnxModelOptions.poolingStrategy().ifPresent(value -> b.poolingStrategy(PoolingStrategy.Enum.valueOf(value))); + onnxModelOptions.executionMode().ifPresent(value -> b.onnxExecutionMode(OnnxExecutionMode.Enum.valueOf(value))); + onnxModelOptions.interOpThreads().ifPresent(b::onnxInterOpThreads); + onnxModelOptions.intraOpThreads().ifPresent(b::onnxIntraOpThreads); + onnxModelOptions.gpuDevice().ifPresent(value -> b.onnxGpuDevice(value.deviceNumber())); } } |