aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorAndrii Yurkiv <dinamicandriy@gmail.com>2023-01-04 22:02:41 +0000
committerAndrii Yurkiv <dinamicandriy@gmail.com>2023-01-04 22:02:41 +0000
commitf9ecee8aa3a93b3cf4cf22cfc5233a52fb697c8d (patch)
treea8a7737c53d055da536bbbf3fbe3ba25a962550a /model-integration
parent261a7682dfef5645431abdc6f458499368ab0f7a (diff)
DJL-based HuggingFaceEmbedder prototype
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/pom.xml6
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java158
-rw-r--r--model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def18
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java50
4 files changed, 232 insertions, 0 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 43f24301d9a..36a2be1df5a 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -63,6 +63,12 @@
<artifactId>protobuf-java</artifactId>
</dependency>
+ <dependency>
+ <groupId>ai.djl.huggingface</groupId>
+ <artifactId>tokenizers</artifactId>
+ <version>0.20.0</version>
+ </dependency>
+
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
new file mode 100644
index 00000000000..2b9e3a2ab60
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -0,0 +1,158 @@
+package ai.vespa.embedding.huggingface;
+
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+
+import java.io.*;
+import java.nio.file.Paths;
+import java.util.*;
+import java.util.stream.Collectors;
+
+import org.slf4j.LoggerFactory;
+import org.slf4j.Logger;
+
+public class HuggingFaceEmbedder implements Embedder {
+
+ private static final Logger LOG = LoggerFactory.getLogger(HuggingFaceEmbedder.class.getName());
+
+ private final String inputIdsName;
+ private final String attentionMaskName;
+ private final String outputName;
+ private final int maxTokens;
+ private final HuggingFaceTokenizer tokenizer;
+ private final OnnxEvaluator evaluator;
+
+ @Inject
+ public HuggingFaceEmbedder(HuggingFaceEmbedderConfig config) throws IOException {
+ maxTokens = config.transformerMaxTokens();
+ inputIdsName = config.transformerInputIds();
+ attentionMaskName = config.transformerAttentionMask();
+ outputName = config.transformerOutput();
+
+ try {
+ ClassLoader tccl = Thread.currentThread().getContextClassLoader();
+ try {
+ Thread.currentThread().setContextClassLoader(getClass().getClassLoader());
+ tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(config.tokenizerPath().toString()));
+ } finally {
+ Thread.currentThread().setContextClassLoader(tccl);
+ }
+ } catch (IOException e){
+ LOG.info("Could not initialize the tokenizer");
+ throw new IOException("Could not initialize the tokenizer.");
+ }
+ evaluator = new OnnxEvaluator(config.transformerModel().toString());
+ validateModel();
+ }
+
+ public void validateModel() {
+ Map<String, TensorType> inputs = evaluator.getInputInfo();
+ validateName(inputs, inputIdsName, "input");
+ validateName(inputs, attentionMaskName, "input");
+
+ Map<String, TensorType> outputs = evaluator.getOutputInfo();
+ validateName(outputs, outputName, "output");
+ }
+
+ private void validateName(Map<String, TensorType> types, String name, String type) {
+ if ( ! types.containsKey(name)) {
+ throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " +
+ "Model contains: " + String.join(",", types.keySet()));
+ }
+ }
+
+ @Override
+ public List<Integer> embed(String s, Context context) {
+ Encoding encoding = tokenizer.encode(s);
+ List<Integer> tokenIds = longToInteger(encoding.getIds());
+
+ int tokensSize = tokenIds.size();
+
+ if (tokensSize > maxTokens) {
+ Integer lastElement = tokenIds.get(tokensSize - 1);
+ tokenIds = tokenIds.subList(0, maxTokens - 1);
+ tokenIds.add(lastElement);
+ }
+ return tokenIds;
+ }
+
+ public List<Integer> longToInteger(long[] values) {
+ return Arrays.stream(values)
+ .boxed().map(Long::intValue)
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public Tensor embed(String s, Context context, TensorType tensorType) {
+ List<Integer> tokenIds = embed(s.toLowerCase(), context);
+ return embedTokens(tokenIds, tensorType);
+ }
+
+ Tensor embedTokens(List<Integer> tokenIds, TensorType tensorType) {
+ Tensor inputSequence = createTensorRepresentation(tokenIds, "d1");
+ Tensor attentionMask = createAttentionMask(inputSequence);
+
+ Map<String, Tensor> inputs = Map.of(
+ inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0")
+ );
+
+ Map<String, Tensor> outputs = evaluator.evaluate(inputs);
+ Tensor tokenEmbeddings = outputs.get(outputName);
+ Tensor.Builder builder = Tensor.Builder.of(tensorType);
+
+ // Mean pooling implementation
+ Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
+ Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
+ Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
+ for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) {
+ builder.cell(averaged.get(TensorAddress.of(0,i)), i);
+ }
+
+ return normalize(builder.build(), tensorType);
+ }
+
+ Tensor normalize(Tensor embedding, TensorType tensorType) {
+ double sumOfSquares = 0.0;
+
+ Tensor.Builder builder = Tensor.Builder.of(tensorType);
+
+ for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) {
+ double item = embedding.get(TensorAddress.of(i));
+ sumOfSquares += item * item;
+ }
+
+ double magnitude = Math.sqrt(sumOfSquares);
+
+ for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) {
+ double value = embedding.get(TensorAddress.of(i));
+ builder.cell(value / magnitude, i);
+ }
+
+ return builder.build();
+ }
+
+ private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) {
+ int size = input.size();
+ TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
+ for (int i = 0; i < size; ++i) {
+ builder.cell(input.get(i), i);
+ }
+ return builder.build();
+ }
+
+ private Tensor createAttentionMask(Tensor inputSequence) {
+ return inputSequence.map((x) -> 1);
+ }
+
+}
+
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
new file mode 100644
index 00000000000..cf7c3b336d6
--- /dev/null
+++ b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
@@ -0,0 +1,18 @@
+namespace=embedding.huggingface
+
+# Path to tokenizer.json
+tokenizerPath model
+
+# Path to model.onnx
+transformerModel model
+
+# Max length of token sequence model can handle
+transformerMaxTokens int default=512
+
+# Input names
+transformerInputIds string default=input_ids
+transformerAttentionMask string default=attention_mask
+
+# Output name
+transformerOutput string default=last_hidden_state
+
diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
new file mode 100644
index 00000000000..c67b6b0dcab
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
@@ -0,0 +1,50 @@
+package ai.vespa.embedding.huggingface;
+
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import com.yahoo.config.ModelReference;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.junit.Assume.assumeTrue;
+import static org.junit.Assert.assertEquals;
+
+public class HuggingFaceEmbedderTest {
+/*
+ @Test
+ public void testEmbedder() {
+
+ String modelPath = "src/test/models/hf/model.onnx";
+ String tokenizerPath = "src/test/models/hf/tokenizer.json";
+ assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+
+ HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
+ builder.tokenizerPath(ModelReference.valueOf(tokenizerPath));
+ builder.transformerModel(ModelReference.valueOf(modelPath));
+
+ HuggingFaceEmbedder embedder;
+
+ try {
+ embedder = new HuggingFaceEmbedder(builder.build());
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ TensorType destType = TensorType.fromSpec("tensor<float>(x[768])");
+ List<Integer> tokens = List.of(1,2,3,4,5);
+ Tensor embedding = embedder.embedTokens(tokens, destType);
+
+ System.out.println(embedding);
+
+ Tensor expected = Tensor.from("tensor<float>(x[768]):[-0.025724048, 0.020880165, -0.011260326, -0.023737747, 0.06904736, -0.023877826, -0.020314846, 0.0032329028, -0.015538657, -0.07391539, 0.017203337, -0.011266706, 0.010958312, 0.011904508, -0.013701068, 0.027089471, -0.016722197, -0.020041, 0.021507785, 0.023721753, -0.07874908, 0.011369475, 0.046657883, -0.042779557, 0.048052263, 0.037120715, 0.0012078708, 0.019323641, 0.013024646, 0.061841156, 0.01753008, 1.2066379E-4, 0.023636049, 0.018369958, 0.036082096, 0.03932147, 0.0046853777, -0.015098697, 0.038477935, -0.01895684, -0.040239938, 2.6470664E-4, 0.03997473, 0.02041734, 0.02412652, 0.018273998, 0.017018031, 0.006871845, 0.0025124447, -0.0018908525, -0.013397233, 0.042458713, 0.007796125, 0.028542817, 0.031890307, -0.0074867285, -0.0033081016, -0.02232893, 0.039048433, 0.00957053, 0.06763975, 0.040223297, 0.0064583384, -0.014190483, 0.045714546, 0.0029999055, 0.014651245, 0.024208939, 0.020654708, -0.012122954, -0.0036424815, 0.00488385, 0.029132547, 0.067792565, 0.0075463247, -0.009096316, -0.038455218, 0.015037789, -0.01743026, -0.004400987, -1.5690622E-4, 0.016168159, 0.0020400928, 0.031062322, -0.008158351, 0.0292213, 0.008834568, -0.048937295, -0.00890528, 0.017726518, -0.0067773387, -0.046057213, -0.066518776, 0.0018978252, -0.04398522, 0.011562229, 0.031211298, 0.0103532905, -0.0037940282, 0.093772806, -0.031089822, 0.040764417, -0.053171575, -0.03156361, 0.036163535, 0.03484915, 0.04917469, -0.0045245993, 0.0058647553, 0.05267792, 0.028012566, 0.028851494, -0.022312999, -0.020575663, 0.03345691, 0.025952421, -0.052168794, -0.061676178, -0.017157838, -0.03421253, -0.035815753, -0.06571464, -0.007408596, -2.02389E-4, -0.023351457, -0.055525146, -0.04038344, -0.006495214, -0.017078917, -0.035309125, 0.041886955, -0.09497299, -0.0189574, 0.016921803, 0.017511738, 0.082098976, -0.018675305, 0.033731908, -0.028046045, -0.013675128, 0.0072140736, -0.020495338, 0.009846083, 0.013070329, -0.011773132, 0.035009257, 0.0074090296, -0.014208246, -5.310546E-4, 0.021474011, 0.014579644, -0.09338692, -0.010726686, 0.007154424, 0.057590302, -0.04826717, 0.040737577, 0.014072642, 0.04285114, -0.061159305, 0.013216943, -0.035471566, -0.03792605, 0.015285408, 0.031102464, 0.030012386, -0.023884479, -0.04371121, -0.024413597, -0.010348542, 0.017916787, 0.0042866515, 0.018110914, -0.041588936, 0.024906408, -0.031663194, 0.03195878, -0.06372821, 0.019083183, -0.01137915, -0.018030347, 0.010138715, -0.0582689, 0.031122282, 0.008210103, 0.012292584, 0.027713217, 0.028951935, 0.045635186, -0.009818348, 0.025670283, 0.03957527, -0.028106295, 0.03346287, 0.006125563, 0.013537182, 0.012909673, -0.001204659, 0.018613683, 0.0018722271, -0.019579338, 0.008905144, -0.05733141, 0.025476566, -0.0056283884, 0.017892752, 0.011068579, 0.07707967, -0.024977751, -0.024308717, 0.013858339, -0.0058020353, -0.014463086, -0.009544265, 0.040218975, -0.012510054, 0.04849776, -0.05000309, 0.025404643, 0.008990219, 0.02775138, -0.07551933, 0.008215385, 0.0053623077, -2.8556216E-4, 0.013400637, 0.017384026, -0.016238615, 0.031755704, -0.06869863, 0.0011450738, 0.04904909, 0.0032084947, -0.061084855, 0.005177811, -0.0043256404, 0.015641086, 0.01082181, -0.04075435, 0.014862946, 0.06862344, 0.008437109, -0.016099032, 0.022712294, -0.034809124, -0.03308236, -0.05667152, -0.03971709, 0.021760954, 0.042704564, -0.003670681, -0.0125031965, -0.01086691, 0.0297599, 5.219019E-4, 0.042474877, -0.010456534, 0.08990086, -0.07252977, -0.0232252, -0.032979038, 0.020222792, 0.040868383, 0.06501842, -0.035030693, -0.0015357807, 0.018102454, 0.024944443, -0.020003196, -0.011539847, 0.011255642, 0.037775412, -0.0037286845, -0.0341213, 0.023036147, 0.02926327, -0.046673402, 0.036873233, -0.03849799, 0.05359753, 0.0020826515, 0.006461479, 0.02670649, -0.00140334, 0.033684377, 0.038561035, -0.024399279, 0.002088306, -0.060904354, -0.075068265, -0.06754775, 0.076485276, -0.017709987, 0.046117906, 0.12425809, 0.0106040435, 0.0935674, -0.038158268, 0.009669471, -0.018891279, -0.008584558, 0.062187072, 0.0446559, -0.04003452, 0.021192033, -0.027830705, 0.0030938783, 0.026238382, 0.050908126, -0.0640897, 0.0039400524, -4.0983717E-4, -0.09788098, 0.077888265, -0.008923493, 9.2718634E-4, -0.003174036, -0.0077122, 0.024076542, -0.012247094, 0.015358698, -0.002875235, -0.03378138, -0.015616789, 0.016734147, 0.0035185486, 0.015807444, 0.03484354, 0.053835943, 0.01872425, -0.018600935, 0.0060353098, -0.0033563771, 0.055035062, -0.083564155, -0.011492768, 0.003962845, -0.03442353, 0.09015563, 0.012225138, 0.031516016, 0.030751515, -0.056343056, 0.037657607, 0.08115837, -0.041137557, 0.016311243, -0.058852646, -0.07653154, 0.02130071, 0.0040857317, -0.020951144, -0.0074253944, 0.05309452, -0.026305407, 0.0056941714, -0.02359672, 0.011392254, 0.017097248, -0.021877138, -0.06543879, 0.0428062, 0.023494843, -0.039750084, 0.0198583, 0.039141204, -0.043232452, 0.05673762, -0.00572516, 0.0099977795, -0.010179716, -0.060138825, 0.031860784, 0.0018468671, -0.010174757, 0.02398504, 0.014412493, 0.079279535, -0.015402895, -0.07597795, 0.0087828515, -0.0127440635, -0.008228165, -0.0019640992, -0.028497383, 0.013919859, 0.025142275, -0.1320675, 0.0121768685, -0.046735562, 9.829229E-5, -0.009189184, 0.018436272, -0.08516998, 0.015040611, 0.035327762, -0.010171434, 0.026718847, -0.028313076, -0.013120813, -0.058203585, -0.038716007, 0.022184927, 0.07012223, -0.06264533, 0.056756523, -0.065681, 0.05986038, -0.05279611, -0.054911636, 0.076010436, 0.041015115, 0.03920821, -0.01744772, 0.0034039353, 0.0075382935, -0.01624392, 0.05378706, 0.03231586, -0.07524116, 0.06305631, 0.05991506, -5.444081E-4, 0.013409323, -0.06888001, -0.040708184, 0.03734671, 0.0052551595, 0.010684721, -0.040529408, 0.028915955, 0.029105747, -0.020185236, 0.06496445, -0.022009412, -0.0033808595, 0.024795303, 0.0026664098, 0.042996325, -0.04022965, 0.012088627, -0.0223725, -0.015508588, -0.013264377, -0.020301288, -0.0015037537, 0.007726907, -0.0022741442, -0.044956572, 0.010999487, -0.0014431779, 0.031763487, 0.019383159, -0.010809799, -0.0134113515, -0.02977723, -0.0014747303, 0.04057383, -0.015751097, -0.011753722, -0.036123946, 0.018938705, 9.906364E-4, 0.036280718, -0.09332089, -0.009991581, 0.025463797, 0.05119224, 0.07540358, 0.027900526, 0.100351, 0.030668264, -0.007963987, -0.029012676, 0.021057166, -0.009048951, 0.00842427, 0.01876811, -0.035510283, 0.034366164, -0.019845309, -0.042352304, 0.061529007, 0.033723388, -0.003314133, -0.024003353, 0.028756566, 0.059479274, -0.064037204, -0.049339823, -8.226961E-4, -0.020002557, -0.011994202, 0.015570834, 0.045298383, 0.0057346253, 0.09007624, -0.053770024, 0.007630297, 0.020868106, -0.017037094, -0.055875137, 0.04900269, 0.015741454, 0.0124805225, -0.0018614308, -0.019576045, 0.023860257, 0.017991606, 0.003367343, 0.06020378, 0.0026180628, -0.09462455, -0.0070169405, -0.029571567, -0.038119137, 0.013861453, -0.017994085, -0.045172486, -0.022872778, 0.055174, -0.008971932, -0.004308986, 0.01601522, 0.003778432, 0.031744134, 0.02868899, -0.14191957, -0.016329547, -0.016410846, -4.6470436E-6, -0.001020947, 0.0027826065, -0.039300438, -0.011893471, -0.02075158, -0.010576237, -0.02062336, 0.013781222, -0.008120074, -0.029703692, -0.046667382, 0.043274097, -0.021984896, -0.02135883, 0.018591158, -0.041193772, -0.0059216945, -0.0011121663, -0.02494825, 0.017716935, -0.009277854, 0.04252703, -0.025771331, -0.04950817, -0.010750714, -0.03249349, -0.051454652, 0.013961526, 0.020731043, 0.005106143, -0.00143041, 0.026762294, -0.040144447, -0.017221546, -0.024441173, 0.026409082, -0.02006987, -0.06430974, 0.03596783, 0.11877633, 0.019118857, -0.023766126, -0.07279529, 0.09964732, -0.021428458, 0.026640266, 0.022268405, 0.042921524, -0.007858052, -0.09624318, -0.022612294, -0.019523097, 0.03567699, 0.03789931, -0.006097838, 0.02569811, 0.0191861, 0.07499048, -0.071985036, 0.02195141, -0.025485674, 0.027281731, 0.028316619, -2.0592185E-4, -0.0087429015, 0.03162398, -0.007593867, -0.008025583, 0.010998485, 0.0040793577, -0.0013161482, 0.04318332, 0.021368232, 0.019170962, 0.021635167, 0.004988852, -0.013367873, -0.012466818, -0.0046749967, -0.03768797, 0.039707363, -0.044927754, -0.03654003, -0.023658205, -0.001842112, -0.010652133, 0.011228231, 9.927069E-4, -0.037655, -3.4657202E-4, -4.4477347E-6, -0.0016849868, -0.08615711, 0.048710153, -0.041956488, 0.043102454, 0.039763212, 0.013289194, -0.080720246, 0.0059994697, 0.015247406, -0.04542366, -0.05336339, -0.054322492, -0.012767407, -0.004596957, -0.025987137, -0.0020473057, -0.007264475, -0.026240809, 0.004853881, 0.010054818, -0.021872481, 0.04792254, -0.017764855, 0.01646331, 0.027268302, 0.042611707, -0.03171807, -0.040693402, -0.021686986, -0.011264477, 0.0067759645, 0.02997798, 0.008916376, 0.02419584, -0.0020763963, 0.0056151943, -0.0026876493, 0.05944909, -0.045404088, 0.018879034, 0.011747689, 0.03196524, -5.1519996E-4, 0.0013718596, 0.058480065, 0.057530686, 0.0032917305, -0.03556252, -0.03199946, 8.2321337E-4, 0.008931801, 0.086252205, 0.0013950337, 0.024274798, -0.009235701, -0.016323563, 0.0069916663, 0.015588893, 0.07079948, 0.019281829, 0.028265161, -0.028714187, 0.041323867, -0.021685174, 0.037033204, 0.040014476, 0.066936225, 0.033902515, -0.0027583768, 0.0102592, -0.025718369, -0.023265323, 0.038798634, 7.6241576E-4, -0.038741548, -0.008511498, 0.0066514956, -0.0047597503, -0.013024812, 0.020948282, 0.032426294, -0.04831275, 0.023370765, -0.026260225, 0.00937463, -0.0136523675, -0.010202122, -0.019113116, 0.02264367, 0.061504897, 0.005174083, 0.009410467, -0.048552092, 0.018883549, 0.017368691, 0.075853914, -0.044583194, -0.018234642, 0.030739859, 0.03802531, -0.0039164764, -0.0034495012, -0.02906537, 0.01969172, 0.0039183535, 0.10471505, 0.0139206285, -0.023619942, -0.062056284, -0.03643699, 0.0075695068, -0.078179, -0.0035506163, 0.055161994, 0.029487751, 0.033565205, -0.034224838, 0.033817288, -0.041085172, -0.017652633, -0.023406055, -0.040258575, -0.02416118, 0.0065094046, -0.034261346, -0.02321457, 0.050748765, 0.0348932, -0.0054060123, 0.052658632, 0.027222686, -0.011120133, 0.026567513, 0.013036436, -0.051871147, -0.062004875, 0.03265022, -0.003253459, -0.047073938, -0.0069678905, 0.04008311, 0.02167116, 0.0023027628, -0.008902592, -0.032181825]");
+
+ assertEquals(embedding, expected);
+
+ }
+*/
+}