about | summary | refs | log | tree | commit | diff | stats
path: root/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java')
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 23
1 files changed, 14 insertions, 9 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 81150fe99b0..bad4bb5c9b3 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -3,22 +3,25 @@ package ai.vespa.embedding.huggingface;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.annotation.Inject;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.jdisc.AbstractResource;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-import java.io.*;
+import java.io.IOException;
import java.nio.file.Paths;
-import java.util.*;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
-import org.slf4j.LoggerFactory;
-import org.slf4j.Logger;
-
-public class HuggingFaceEmbedder implements Embedder {
+public class HuggingFaceEmbedder extends AbstractResource implements Embedder {
private static final Logger LOG = LoggerFactory.getLogger(HuggingFaceEmbedder.class.getName());
@@ -30,7 +33,7 @@ public class HuggingFaceEmbedder implements Embedder {
private final OnnxEvaluator evaluator;
@Inject
- public HuggingFaceEmbedder(HuggingFaceEmbedderConfig config) throws IOException {
+ public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException {
maxTokens = config.transformerMaxTokens();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
@@ -48,7 +51,7 @@ public class HuggingFaceEmbedder implements Embedder {
LOG.info("Could not initialize the tokenizer");
throw new IOException("Could not initialize the tokenizer.");
}
- evaluator = new OnnxEvaluator(config.transformerModel().toString());
+ evaluator = onnx.evaluatorOf(config.transformerModel().toString());
validateModel();
}
@@ -83,6 +86,8 @@ public class HuggingFaceEmbedder implements Embedder {
return tokenIds;
}
+ @Override protected void destroy() { evaluator.close(); } // Closes the ONNX evaluator obtained in the constructor; presumably invoked by AbstractResource when the last reference is released (TODO confirm).
+
public List<Integer> longToInteger(long[] values) {
return Arrays.stream(values)
.boxed().map(Long::intValue)