diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 81150fe99b0..bad4bb5c9b3 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -3,22 +3,25 @@ package ai.vespa.embedding.huggingface; import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.annotation.Inject; +import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import com.yahoo.jdisc.AbstractResource; import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import java.io.*; +import java.io.IOException; import java.nio.file.Paths; -import java.util.*; +import java.util.Arrays; +import java.util.List; +import java.util.Map; -import org.slf4j.LoggerFactory; -import org.slf4j.Logger; - -public class HuggingFaceEmbedder implements Embedder { +public class HuggingFaceEmbedder extends AbstractResource implements Embedder { private static final Logger LOG = LoggerFactory.getLogger(HuggingFaceEmbedder.class.getName()); @@ -30,7 +33,7 @@ public class HuggingFaceEmbedder implements Embedder { private final OnnxEvaluator evaluator; @Inject - public HuggingFaceEmbedder(HuggingFaceEmbedderConfig config) throws IOException { + public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException { maxTokens = config.transformerMaxTokens(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); @@ -48,7 +51,7 @@ public class HuggingFaceEmbedder implements Embedder { LOG.info("Could not initialize the tokenizer"); throw new IOException("Could not initialize the tokenizer."); } - evaluator = new OnnxEvaluator(config.transformerModel().toString()); + evaluator = onnx.evaluatorOf(config.transformerModel().toString()); validateModel(); } @@ -83,6 +86,8 @@ public class HuggingFaceEmbedder implements Embedder { return tokenIds; } + @Override protected void destroy() { evaluator.close(); } + public List<Integer> longToInteger(long[] values) { return Arrays.stream(values) .boxed().map(Long::intValue) |