aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java')
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java47
1 files changed, 47 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
index d504d77cc9b..c2c37db31f6 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
@@ -143,6 +143,39 @@ public class HuggingFaceEmbedderTest {
});
}
+ @Test
+ public void testEmbedderWithNormalizationAndPrefix() {
+ String input = "This is a test";
+ var context = new Embedder.Context("schema.indexing");
+ Tensor result = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])")));
+ assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
+ result = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor<float>(x[16])")));
+ assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
+ Tensor binarizedResult = getNormalizePrefixdEmbedder().embed(input, context, TensorType.fromSpec(("tensor<int8>(x[2])")));
+ assertEquals("tensor<int8>(x[2]):[125, 44]", binarizedResult.toAbbreviatedString());
+
+ var queryContext = new Embedder.Context("query.qt");
+ Tensor queryResult = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor<float>(x[8])")));
+ assertEquals(1.0, queryResult.multiply(queryResult).sum().asDouble(), 1e-3);
+ queryResult = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor<float>(x[16])")));
+ assertEquals(1.0, queryResult.multiply(queryResult).sum().asDouble(), 1e-3);
+ Tensor binarizedResultQuery = getNormalizePrefixdEmbedder().embed(input, queryContext, TensorType.fromSpec(("tensor<int8>(x[2])")));
+ assertNotEquals(binarizedResult.toAbbreviatedString(), binarizedResultQuery.toAbbreviatedString());
+ assertEquals("tensor<int8>(x[2]):[119, -116]", binarizedResultQuery.toAbbreviatedString());
+ }
+
+ @Test
+ public void testPrepend() {
+ var context = new Embedder.Context("schema.indexing");
+ String input = "This is a test";
+ var embedder = getNormalizePrefixdEmbedder();
+ var result = embedder.prependInstruction(input, context);
+ assertEquals("This is a document: This is a test", result);
+ var queryContext = new Embedder.Context("query.qt");
+ var queryResult = embedder.prependInstruction(input, queryContext);
+ assertEquals("Represent this text: This is a test", queryResult);
+ }
+
private static HuggingFaceEmbedder getEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";
@@ -165,6 +198,20 @@ public class HuggingFaceEmbedderTest {
return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
}
+ private static HuggingFaceEmbedder getNormalizePrefixdEmbedder() {
+ String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
+ String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
+ HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
+ builder.tokenizerPath(ModelReference.valueOf(vocabPath));
+ builder.transformerModel(ModelReference.valueOf(modelPath));
+ builder.transformerGpuDevice(-1);
+ builder.normalize(true);
+ builder.prependQuery("Represent this text:");
+ builder.prependDocument("This is a document:");
+ return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ }
+
public static Tensor expandBitTensor(Tensor packed) {
var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.DOUBLE, "big");
var context = new MapContext();