aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/schema/processing/TensorFieldTestCase.java
blob: b7817900dac9a490db43697b727f0e2cc382225d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.processing;

import com.yahoo.schema.document.Attribute;
import com.yahoo.schema.parser.ParseException;
import org.junit.jupiter.api.Test;


import static com.yahoo.schema.ApplicationBuilder.createFromString;
import static com.yahoo.config.model.test.TestUtil.joinLines;
import static org.junit.jupiter.api.Assertions.*;

/**
 * @author geirst
 */
public class TensorFieldTestCase {

    @Test
    void requireThatTensorFieldCannotBeOfCollectionType() throws ParseException {
        try {
            createFromString(getSd("field f1 type array<tensor(x{})> {}"));
            fail("Expected exception");
        }
        catch (IllegalArgumentException e) {
            assertEquals("For schema 'test', field 'f1': A field with collection type of tensor is not supported. Use simple type 'tensor' instead.",
                    e.getMessage());
        }
    }

    @Test
    void requireThatTensorFieldCannotBeIndexField() throws ParseException {
        try {
            createFromString(getSd("field f1 type tensor(x{}) { indexing: index }"));
            fail("Expected exception");
        }
        catch (IllegalArgumentException e) {
            assertEquals("For schema 'test', field 'f1': A tensor of type 'tensor(x{})' does not support having an 'index'. " +
                    "Currently, only tensors with 1 indexed dimension or 1 mapped + 1 indexed dimension support that.",
                    e.getMessage());
        }
    }

    @Test
    void requireThatIndexedTensorAttributeCannotBeFastSearch() throws ParseException {
        try {
            createFromString(getSd("field f1 type tensor(x[3]) { indexing: attribute \n attribute: fast-search }"));
            fail("Expected exception");
        }
        catch (IllegalArgumentException e) {
            assertEquals("For schema 'test', field 'f1': An attribute of type 'tensor' cannot be 'fast-search'.", e.getMessage());
        }
    }

    @Test
    void requireThatIndexedTensorAttributeCannotBeFastRank() throws ParseException {
        try {
            createFromString(getSd("field f1 type tensor(x[3]) { indexing: attribute \n attribute: fast-rank }"));
            fail("Expected exception");
        }
        catch (IllegalArgumentException e) {
            assertEquals("The attribute 'f1' (tensor(x[3])) does not support 'fast-rank'. Only supported for tensor types with at least one mapped dimension", e.getMessage());
        }
    }

    @Test
    void requireThatIllegalTensorTypeSpecThrowsException() throws ParseException {
        try {
            createFromString(getSd("field f1 type tensor(invalid) { indexing: attribute }"));
            fail("Expected exception");
        }
        catch (IllegalArgumentException e) {
            assertStartsWith("Field type: Illegal tensor type spec:", e.getMessage());
        }
    }

    @Test
    void hnsw_index_is_default_turned_off() throws ParseException {
        var attr = getAttributeFromSd("field t1 type tensor(x[64]) { indexing: attribute }", "t1");
        assertFalse(attr.hnswIndexParams().isPresent());
    }

    @Test
    void hnsw_index_gets_default_parameters_if_not_specified() throws ParseException {
        assertHnswIndexParams("", 16, 200);
        assertHnswIndexParams("index: hnsw", 16, 200);
    }

    @Test
    void tensor_with_one_mapped_and_one_indexed_dimension_can_have_hnsw_index() throws ParseException {
        assertHnswIndexParams("tensor(x{},y[64])", "", 16, 200);
        assertHnswIndexParams("tensor(x[64],y{})", "", 16, 200);
    }

    @Test
    void hnsw_index_parameters_can_be_specified() throws ParseException {
        assertHnswIndexParams("index { hnsw { max-links-per-node: 32 } }", 32, 200);
        assertHnswIndexParams("index { hnsw { neighbors-to-explore-at-insert: 300 } }", 16, 300);
        assertHnswIndexParams(joinLines("index {",
                "  hnsw {",
                "    max-links-per-node: 32",
                "    neighbors-to-explore-at-insert: 300",
                "  }",
                "}"),
                32, 300);
    }

    @Test
    void tensor_with_hnsw_index_must_be_an_attribute() throws ParseException {
        try {
            createFromString(getSd("field t1 type tensor(x[64]) { indexing: index }"));
            fail("Expected exception");
        }
        catch (IllegalArgumentException e) {
            assertEquals("For schema 'test', field 't1': A tensor that has an index must also be an attribute.", e.getMessage());
        }
    }

    @Test
    void tensor_with_hnsw_index_parameters_must_be_an_index() throws ParseException {
        try {
            createFromString(getSd(joinLines(
                    "field t1 type tensor(x[64]) {",
                    "  indexing: attribute ",
                    "  index {",
                    "    hnsw { max-links-per-node: 32 }",
                    "  }",
                    "}")));
            fail("Expected exception");
        }
        catch (IllegalArgumentException e) {
            assertEquals("For schema 'test', field 't1': " +
                    "A tensor that specifies hnsw index parameters must also specify 'index' in 'indexing'",
                    e.getMessage());
        }
    }

    @Test
    void tensors_with_at_least_one_mapped_dimension_can_be_direct() throws ParseException {
        assertTrue(getAttributeFromSd(
                "field t1 type tensor(x{}) { indexing: attribute \n attribute: fast-search }", "t1").isFastSearch());
        assertTrue(getAttributeFromSd(
                "field t1 type tensor(x{},y{},z[4]) { indexing: attribute \n attribute: fast-search }", "t1").isFastSearch());
    }

    @Test
    void tensors_with_at_least_one_mapped_dimension_can_be_fast_rank() throws ParseException {
        assertTrue(getAttributeFromSd(
                "field t1 type tensor(x{}) { indexing: attribute \n attribute: fast-rank }", "t1").isFastRank());
        assertTrue(getAttributeFromSd(
                "field t1 type tensor(x{},y{},z[4]) { indexing: attribute \n attribute: fast-rank }", "t1").isFastRank());
    }

    private static String getSd(String field) {
        return joinLines("search test {",
                "  document test {",
                "    " + field,
                "  }",
                "}");
    }

    private Attribute getAttributeFromSd(String fieldSpec, String attrName) throws ParseException {
        return createFromString(getSd(fieldSpec)).getSchema().getAttribute(attrName);
    }

    private void assertHnswIndexParams(String indexSpec, int maxLinksPerNode, int neighborsToExploreAtInsert) throws ParseException {
        assertHnswIndexParams("tensor(x[64])", indexSpec, maxLinksPerNode, neighborsToExploreAtInsert);
    }

    private void assertHnswIndexParams(String tensorType, String indexSpec, int maxLinksPerNode, int neighborsToExploreAtInsert) throws ParseException {
        var sd = getSdWithIndexSpec(tensorType, indexSpec);
        var search = createFromString(sd).getSchema();
        var attr = search.getAttribute("t1");
        var params = attr.hnswIndexParams();
        assertTrue(params.isPresent());
        assertEquals(maxLinksPerNode, params.get().maxLinksPerNode());
        assertEquals(neighborsToExploreAtInsert, params.get().neighborsToExploreAtInsert());
    }

    private String getSdWithIndexSpec(String tensorType, String indexSpec) {
        return getSd(joinLines("field t1 type " + tensorType + " {",
                "  indexing: attribute | index",
                "  " + indexSpec,
                "}"));
    }

    private void assertStartsWith(String prefix, String string) {
        assertEquals(prefix, string.substring(0, Math.min(prefix.length(), string.length())));
    }

}