1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
/**
* @author bjorncs
*/
public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer {
private final ModelReference modelRef;
private final ModelReference vocabRef;
private final Integer maxTokens;
private final String transformerInputIds;
private final String transformerAttentionMask;
private final String transformerTokenTypeIds;
private final String transformerOutput;
private final Boolean normalize;
private final String onnxExecutionMode;
private final Integer onnxInteropThreads;
private final Integer onnxIntraopThreads;
private final Integer onnxGpuDevice;
private final String poolingStrategy;
public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
modelRef = model.modelReference();
vocabRef = Model.fromXml(state, xml, "tokenizer-model")
.map(Model::modelReference)
.orElseGet(() -> resolveDefaultVocab(model, state));
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null);
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null);
onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
model.registerOnnxModelCost(cluster);
}
private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
var modelId = model.modelId().orElse(null);
if (state.isHosted() && modelId != null) {
return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference();
}
throw new IllegalArgumentException("'tokenizer-model' must be specified");
}
@Override
public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
b.transformerModel(modelRef).tokenizerPath(vocabRef);
if (maxTokens != null) b.transformerMaxTokens(maxTokens);
if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
if (transformerOutput != null) b.transformerOutput(transformerOutput);
if (normalize != null) b.normalize(normalize);
if (onnxExecutionMode != null) b.transformerExecutionMode(
HuggingFaceEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode));
if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy));
}
}
|