aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
blob: d504d77cc9b853a3b46452a7ecb9da5370a257c5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.embedding.huggingface;


import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TensorAddress;
import org.junit.Test;

import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assume.assumeTrue;
import static org.junit.jupiter.api.Assertions.*;

import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode;

/**
 * Tests of {@link HuggingFaceEmbedder}: packing of float embeddings into int8 bit vectors, caching of model
 * output in the {@link Embedder.Context}, embedding into float, bfloat16 and int8 target types, normalization,
 * and rejection of unsupported target tensor types.
 *
 * Tests that need a real embedder are skipped when the ONNX runtime is not available for the test model;
 * {@link #testBinarization()} does not use an embedder instance and always runs.
 */
public class HuggingFaceEmbedderTest {

    private static final String TOKENIZER_PATH = "src/test/models/onnx/transformer/real_tokenizer.json";
    private static final String MODEL_PATH = "src/test/models/onnx/transformer/embedding_model.onnx";

    // Both are null when the ONNX runtime is unavailable for the test model. The availability assumption is
    // deliberately NOT evaluated here: an assumption failure thrown from a static initializer surfaces as an
    // ExceptionInInitializerError and fails the class instead of skipping its tests.
    static HuggingFaceEmbedder embedder = getEmbedder();
    static HuggingFaceEmbedder normalizedEmbedder = getNormalizedEmbedder();

    /**
     * Verifies bit packing of 0/1 float vectors into int8 cells. As the expectations show, the first input
     * cell ends up in the most significant bit of each byte (e.g. [1,1,1,1,0,0,0,0] packs to -16).
     */
    @Test
    public void testBinarization() {
        TensorType typeOne = TensorType.fromSpec("tensor<int8>(x[1])");
        TensorType typeTwo = TensorType.fromSpec("tensor<int8>(x[2])");

        // All bits cleared / all bits set
        assertPackRight("tensor(x[8]):[0,0,0,0,0,0,0,0]", "tensor<int8>(x[1]):[0]", typeOne);
        assertPackRight("tensor(x[8]):[1,1,1,1,1,1,1,1]", "tensor<int8>(x[1]):[-1]", typeOne);
        assertPackRight("tensor(x[16]):[0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1]", "tensor<int8>(x[2]):[0, -1]", typeTwo);

        // Alternating bits
        assertPackRight("tensor(x[8]):[0,1,0,1,0,1,0,1]", "tensor<int8>(x[1]):[85]", typeOne);
        assertPackRight("tensor(x[8]):[1,0,1,0,1,0,1,0]", "tensor<int8>(x[1]):[-86]", typeOne);
        assertPackRight("tensor(x[16]):[0,1,0,1,0,1,0,1,1,0,1,0,1,0,1,0]", "tensor<int8>(x[2]):[85, -86]", typeTwo);

        // High nibble / low nibble
        assertPackRight("tensor(x[8]):[1,1,1,1,0,0,0,0]", "tensor<int8>(x[1]):[-16]", typeOne);
        assertPackRight("tensor(x[8]):[0,0,0,0,1,1,1,1]", "tensor<int8>(x[1]):[15]", typeOne);
        assertPackRight("tensor(x[16]):[1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1]", "tensor<int8>(x[2]):[-16, 15]", typeTwo);
    }

    /**
     * Asserts that binarizing {@code input} into {@code type} yields {@code expected}, and that unpacking the
     * packed tensor again reproduces the input.
     *
     * @param input    tensor literal holding the 0/1 values to pack
     * @param expected string form of the expected packed tensor
     * @param type     the int8 target type to pack into
     */
    private void assertPackRight(String input, String expected, TensorType type) {
        Tensor inputTensor = Tensor.from(input);
        Tensor packed = HuggingFaceEmbedder.binarize((IndexedTensor) inputTensor, type);
        assertEquals(expected, packed.toString());
        // Verify that the unpack_bits ranking feature produces compatible output
        Tensor unpacked = expandBitTensor(packed);
        assertEquals(inputTensor.toString(), unpacked.toString());
    }

    /**
     * Verifies that model output is cached in the context under (embedder id, input text), that the cached
     * value is reused when the same text is embedded into another target type, and that other texts and
     * other embedder ids get their own entries.
     */
    @Test
    public void testCaching() {
        assumeEmbeddersAvailable();
        var context = new Embedder.Context("schema.indexing");
        var myEmbedderId = "my-hf-embedder";
        context.setEmbedderId(myEmbedderId);

        var input = "This is a test string to embed";
        Tensor result = embedder.embed(input, context, TensorType.fromSpec("tensor<float>(x[8])"));
        var inputKey = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, input);
        var modelOutput = context.getCachedValue(inputKey);
        assertNotNull(modelOutput);

        // Same text, different target type: the cached model output is unchanged, the embedding differs
        Tensor binaryResult = embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(x[4])"));
        var modelOutputAfterBinary = context.getCachedValue(inputKey);
        assertEquals(modelOutput, modelOutputAfterBinary);
        assertNotEquals(result, binaryResult);

        // Different text: a separate cache entry
        var anotherInput = "This is a different test string to embed with the same embedder";
        embedder.embed(anotherInput, context, TensorType.fromSpec("tensor<float>(x[4])"));
        var anotherInputKey = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, anotherInput);
        var anotherModelOutput = context.getCachedValue(anotherInputKey);
        assertNotEquals(modelOutput, anotherModelOutput);

        // context cache is shared, but entries are keyed by embedder id: nothing is cached for the new id
        // until it has embedded the text itself
        var copyContext = context.copy();
        var anotherEmbedderId = "another-hf-embedder";
        copyContext.setEmbedderId(anotherEmbedderId);
        var anotherEmbedderKey = new HuggingFaceEmbedder.HFEmbedderCacheKey(anotherEmbedderId, input);
        assertNull(copyContext.getCachedValue(anotherEmbedderKey));
        embedder.embed(input, copyContext, TensorType.fromSpec("tensor<int8>(x[2])"));
        assertNotEquals(modelOutput, copyContext.getCachedValue(anotherEmbedderKey));
    }

    /**
     * Verifies embedding into float, int8 (bit-packed) and bfloat16 target types, and that requesting more
     * packed dimensions than the model output can fill is rejected.
     */
    @Test
    public void testEmbedder() {
        assumeEmbeddersAvailable();
        var context = new Embedder.Context("schema.indexing");
        String input = "This is a test";
        Tensor expected = Tensor.from("tensor<float>(x[8]):[-0.666, 0.335, 0.227, 0.0919, -0.069, 0.323, 0.422, 0.270]");
        Tensor result = embedder.embed(input, context, TensorType.fromSpec("tensor<float>(x[8])"));
        for (int i = 0; i < 8; i++) {
            assertEquals(expected.get(TensorAddress.of(i)), result.get(TensorAddress.of(i)), 1e-2);
        }

        // Thresholding on the above gives [0, 1, 1, 1, 0, 1, 1, 1] which is packed into 119 (int8)
        Tensor binarizedResult = embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(x[1])"));
        assertEquals("tensor<int8>(x[1]):[119]", binarizedResult.toString());

        binarizedResult = embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(x[2])"));
        assertEquals("tensor<int8>(x[2]):[119, 44]", binarizedResult.toAbbreviatedString());

        binarizedResult = embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(x[48])"));
        assertTrue(binarizedResult.toAbbreviatedString().startsWith("tensor<int8>(x[48]):[119, 44"));

        assertThrows(IllegalArgumentException.class, () -> {
            // throws because the target tensor type is not compatible with the model output
            // 49*8 > 384
            embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(x[49])"));
        });

        // A single bfloat16 cell holds the first embedding value, so summing the tensor yields that value
        Tensor bfloat16Result = embedder.embed(input, context, TensorType.fromSpec("tensor<bfloat16>(x[1])"));
        assertEquals(-0.666, bfloat16Result.sum().asDouble(), 1e-3);
    }

    /** Verifies that the normalizing embedder produces unit-length vectors for float target types. */
    @Test
    public void testEmbedderWithNormalization() {
        assumeEmbeddersAvailable();
        String input = "This is a test";
        var context = new Embedder.Context("schema.indexing");

        // Sum of squares of a unit-length vector is 1
        Tensor result = normalizedEmbedder.embed(input, context, TensorType.fromSpec("tensor<float>(x[8])"));
        assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
        result = normalizedEmbedder.embed(input, context, TensorType.fromSpec("tensor<float>(x[16])"));
        assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);

        // Binarization through the normalizing embedder; the plain embedder is already covered by
        // testEmbedder. The expected packing is the same as there.
        // NOTE(review): assumes normalization only rescales and so leaves the thresholded bits unchanged - confirm.
        Tensor binarizedResult = normalizedEmbedder.embed(input, context, TensorType.fromSpec("tensor<int8>(x[2])"));
        assertEquals("tensor<int8>(x[2]):[119, 44]", binarizedResult.toAbbreviatedString());
    }

    /** Verifies that target tensor types which are not a single, non-empty indexed dimension are rejected. */
    @Test
    public void testThatWrongTensorTypeThrows() {
        assumeEmbeddersAvailable();
        var context = new Embedder.Context("schema.indexing");
        String input = "This is a test";
        assertThrows(IllegalArgumentException.class, () -> {
            // throws because the target tensor type is mapped
            embedder.embed(input, context, TensorType.fromSpec("tensor<float>(x{})"));
        });
        assertThrows(IllegalArgumentException.class, () -> {
            // throws because the target dimension has size 0. The spec must be well-formed (closing
            // parenthesis included), otherwise the exception only reflects a spec syntax error.
            embedder.embed(input, context, TensorType.fromSpec("tensor<float>(x[0])"));
        });
        assertThrows(IllegalArgumentException.class, () -> {
            // throws because the target tensor is 2d
            embedder.embed(input, context, TensorType.fromSpec("tensor<float>(x{}, y[2])"));
        });
    }

    /** Skips the calling test, rather than failing it, when the embedders could not be created. */
    private static void assumeEmbeddersAvailable() {
        assumeTrue(embedder != null && normalizedEmbedder != null);
    }

    /** Returns an embedder with default normalization settings, or null if the ONNX runtime is unavailable. */
    private static HuggingFaceEmbedder getEmbedder() {
        return createEmbedder(false);
    }

    /** Returns an embedder with normalization enabled, or null if the ONNX runtime is unavailable. */
    private static HuggingFaceEmbedder getNormalizedEmbedder() {
        return createEmbedder(true);
    }

    /**
     * Creates an embedder over the test tokenizer and model, running on CPU (GPU device -1).
     *
     * @param normalize whether to explicitly enable normalization; when false the config default is kept
     * @return the embedder, or null if the ONNX runtime is not available for the test model
     */
    private static HuggingFaceEmbedder createEmbedder(boolean normalize) {
        if (!OnnxRuntime.isRuntimeAvailable(MODEL_PATH)) {
            return null;
        }
        HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
        builder.tokenizerPath(ModelReference.valueOf(TOKENIZER_PATH));
        builder.transformerModel(ModelReference.valueOf(MODEL_PATH));
        builder.transformerGpuDevice(-1);
        if (normalize) {
            builder.normalize(true);
        }
        return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
    }

    /**
     * Unpacks a bit-packed tensor by evaluating an {@link UnpackBitsNode} over it, producing double cells
     * and using "big" as the bit order argument.
     *
     * @param packed the packed tensor to expand
     * @return the tensor produced by evaluating the unpack expression
     */
    public static Tensor expandBitTensor(Tensor packed) {
        var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.DOUBLE, "big");
        var context = new MapContext();
        context.put("input", new TensorValue(packed));
        return unpacker.evaluate(context).asTensor();
    }
}