// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.embedding;

import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assume.assumeTrue;

public class ColBertEmbedderTest {

    @Test
    public void testPacking() {
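        // toBitTensor packs each dense row of 0/1 floats into int8 values, 8 bits per byte, most significant bit first:
        // [0,0,0,0,0,0,0,1] -> 1, [1,0,0,0,0,0,0,0] -> -128, [1,1,1,1,1,1,1,1] -> -1.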
        assertPackedRight(
                "" +
                        "tensor<float>(d1[6],d2[8]):" +
                        "[" +
                        "[0, 0, 0, 0, 0, 0, 0, 1]," +
                        "[0, 0, 0, 0, 0, 1, 0, 1]," +
                        "[0, 0, 0, 0, 0, 0, 1, 1]," +
                        "[0, 1, 1, 1, 1, 1, 1, 1]," +
                        "[1, 0, 0, 0, 0, 0, 0, 0]," +
                        "[1, 1, 1, 1, 1, 1, 1, 1]" +
                        "]",
                TensorType.fromSpec("tensor<int8>(dt{},x[1])"),
                "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}", 6
        );
        assertPackedRight(
                "" +
                        "tensor<float>(d1[2],d2[16]):" +
                        "[" +
                        "[0, 0, 0, 0, 0, 0, 0, 1,   1, 0, 0, 0, 0, 0, 0, 0]," +
                        "[0, 0, 0, 0, 0, 1, 0, 1,   0, 0, 0, 0, 0, 0, 0, 1]" +
                        "]",
                TensorType.fromSpec("tensor<int8>(dt{},x[2])"),
                "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2
        );
    }

    @Test
    public void testEmbedder() {
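        // The model outputs 128 floats per token; documents (dt) may be requested as float x[128] or bit-packed int8 x[16],
        // queries (qt) only as float x[128].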
        assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext);
        assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext);
        assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext);

        assertThrows(IllegalArgumentException.class, () -> {
            // throws because int8 is not supported for the query context
            assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext);
        });
        assertThrows(IllegalArgumentException.class, () -> {
            // throws because the requested float dimension (16) is smaller than the model output dimension (128)
            assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext);
        });

        assertThrows(IllegalArgumentException.class, () -> {
            // throws because the 128 model dimensions pack into 128/8 = 16 int8 values, which do not fit into x[15]
            assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext);
        });
    }

    @Test
    public void testLengthLimits() {
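        // Repeat a word well beyond the model limits: documents are capped at 511 token vectors, queries at 32.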
        StringBuilder sb = new StringBuilder();
        for(int i = 0; i < 1024; i++) {
            sb.append("annoyance");
            sb.append(" ");
        }
        String text = sb.toString();
        Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
        assertEquals(511*128,fullFloat.size());

        Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext);
        assertEquals(32*128,query.size());

        Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext);
        assertEquals(511*16,binaryRep.size());

        Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext);
        // 3 token vectors (CLS, [unused1] and the input text), 16 bytes each = 48 bytes
        assertEquals(3*16,shortDoc.size());
    }

    static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) {
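        // Embeds the text into the requested tensor type and verifies the result type;
        // query embeddings always contain exactly 32 token vectors.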
        TensorType destType = TensorType.fromSpec(tensorSpec);
        Tensor result = embedder.embed(text, context, destType);
        assertEquals(destType,result.type());
        MixedTensor mixedTensor = (MixedTensor) result;
        if(context == queryContext) {
            assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size());
        }
        return result;
    }

    static void assertPackedRight(String numbers, TensorType destination, String expected, int size) {
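        // Bit-packs the dense 0/1 float tensor into the int8 destination type and compares the string form.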
        Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination, size);
        assertEquals(expected,packed.toString());
    }

    static final Embedder embedder;
    static final Embedder.Context indexingContext;
    static final Embedder.Context queryContext;
    static {
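        // Shared fixtures: one embedder instance plus the document-indexing and query contexts used above.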
        indexingContext = new Embedder.Context("schema.indexing");
        queryContext = new Embedder.Context("query(qt)");
        embedder = getEmbedder();
    }
    private static Embedder getEmbedder() {
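        // Loads the dummy ColBERT ONNX model and tokenizer from src/test/models;
        // tests are skipped when no ONNX runtime is available on this platform.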
        String vocabPath = "src/test/models/onnx/transformer/tokenizer.json";
        String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
        assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
        ColBertEmbedderConfig.Builder builder = new ColBertEmbedderConfig.Builder();
        builder.tokenizerPath(ModelReference.valueOf(vocabPath));
        builder.transformerModel(ModelReference.valueOf(modelPath));
        builder.transformerGpuDevice(-1);
        return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
    }
}