summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/schema/processing/TensorFieldTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/schema/processing/TensorFieldTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/TensorFieldTestCase.java172
1 files changed, 172 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/TensorFieldTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/TensorFieldTestCase.java
new file mode 100644
index 00000000000..67c77508e3b
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/schema/processing/TensorFieldTestCase.java
@@ -0,0 +1,172 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.processing;
+
+import com.yahoo.schema.document.Attribute;
+import com.yahoo.schema.parser.ParseException;
+import org.junit.Test;
+
+
+import static com.yahoo.schema.ApplicationBuilder.createFromString;
+import static com.yahoo.config.model.test.TestUtil.joinLines;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * @author geirst
+ */
+public class TensorFieldTestCase {
+
+ @Test
+ public void requireThatTensorFieldCannotBeOfCollectionType() throws ParseException {
+ try {
+ createFromString(getSd("field f1 type array<tensor(x{})> {}"));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("For schema 'test', field 'f1': A field with collection type of tensor is not supported. Use simple type 'tensor' instead.",
+ e.getMessage());
+ }
+ }
+
+ @Test
+ public void requireThatTensorFieldCannotBeIndexField() throws ParseException {
+ try {
+ createFromString(getSd("field f1 type tensor(x{}) { indexing: index }"));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("For schema 'test', field 'f1': A tensor of type 'tensor(x{})' does not support having an 'index'. " +
+ "Currently, only tensors with 1 indexed dimension supports that.",
+ e.getMessage());
+ }
+ }
+
+ @Test
+ public void requireThatIndexedTensorAttributeCannotBeFastSearch() throws ParseException {
+ try {
+ createFromString(getSd("field f1 type tensor(x[3]) { indexing: attribute \n attribute: fast-search }"));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("For schema 'test', field 'f1': An attribute of type 'tensor' cannot be 'fast-search'.", e.getMessage());
+ }
+ }
+
+ @Test
+ public void requireThatIllegalTensorTypeSpecThrowsException() throws ParseException {
+ try {
+ createFromString(getSd("field f1 type tensor(invalid) { indexing: attribute }"));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertStartsWith("Field type: Illegal tensor type spec:", e.getMessage());
+ }
+ }
+
+ @Test
+ public void hnsw_index_is_default_turned_off() throws ParseException {
+ var attr = getAttributeFromSd("field t1 type tensor(x[64]) { indexing: attribute }", "t1");
+ assertFalse(attr.hnswIndexParams().isPresent());
+ }
+
+ @Test
+ public void hnsw_index_gets_default_parameters_if_not_specified() throws ParseException {
+ assertHnswIndexParams("", 16, 200);
+ assertHnswIndexParams("index: hnsw", 16, 200);
+ }
+
+ @Test
+ public void hnsw_index_parameters_can_be_specified() throws ParseException {
+ assertHnswIndexParams("index { hnsw { max-links-per-node: 32 } }", 32, 200);
+ assertHnswIndexParams("index { hnsw { neighbors-to-explore-at-insert: 300 } }", 16, 300);
+ assertHnswIndexParams(joinLines("index {",
+ " hnsw {",
+ " max-links-per-node: 32",
+ " neighbors-to-explore-at-insert: 300",
+ " }",
+ "}"),
+ 32, 300);
+ }
+
+ @Test
+ public void tensor_with_hnsw_index_must_be_an_attribute() throws ParseException {
+ try {
+ createFromString(getSd("field t1 type tensor(x[64]) { indexing: index }"));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("For schema 'test', field 't1': A tensor that has an index must also be an attribute.", e.getMessage());
+ }
+ }
+
+ @Test
+ public void tensor_with_hnsw_index_parameters_must_be_an_index() throws ParseException {
+ try {
+ createFromString(getSd(joinLines(
+ "field t1 type tensor(x[64]) {",
+ " indexing: attribute ",
+ " index {",
+ " hnsw { max-links-per-node: 32 }",
+ " }",
+ "}")));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("For schema 'test', field 't1': " +
+ "A tensor that specifies hnsw index parameters must also specify 'index' in 'indexing'",
+ e.getMessage());
+ }
+ }
+
+ @Test
+ public void tensors_with_at_least_one_mapped_dimension_can_be_direct() throws ParseException {
+ assertTrue(getAttributeFromSd(
+ "field t1 type tensor(x{}) { indexing: attribute \n attribute: fast-search }", "t1").isFastSearch());
+ assertTrue(getAttributeFromSd(
+ "field t1 type tensor(x{},y{},z[4]) { indexing: attribute \n attribute: fast-search }", "t1").isFastSearch());
+ }
+
+ @Test
+ public void tensors_with_at_least_one_mapped_dimension_can_be_fast_rank() throws ParseException {
+ assertTrue(getAttributeFromSd(
+ "field t1 type tensor(x{}) { indexing: attribute \n attribute: fast-rank }", "t1").isFastRank());
+ assertTrue(getAttributeFromSd(
+ "field t1 type tensor(x{},y{},z[4]) { indexing: attribute \n attribute: fast-rank }", "t1").isFastRank());
+ }
+
+ private static String getSd(String field) {
+ return joinLines("search test {",
+ " document test {",
+ " " + field,
+ " }",
+ "}");
+ }
+
+ private Attribute getAttributeFromSd(String fieldSpec, String attrName) throws ParseException {
+ return createFromString(getSd(fieldSpec)).getSchema().getAttribute(attrName);
+ }
+
+ private void assertHnswIndexParams(String indexSpec, int maxLinksPerNode, int neighborsToExploreAtInsert) throws ParseException {
+ var sd = getSdWithIndexSpec(indexSpec);
+ var search = createFromString(sd).getSchema();
+ var attr = search.getAttribute("t1");
+ var params = attr.hnswIndexParams();
+ assertTrue(params.isPresent());
+ assertEquals(maxLinksPerNode, params.get().maxLinksPerNode());
+ assertEquals(neighborsToExploreAtInsert, params.get().neighborsToExploreAtInsert());
+ }
+
+ private String getSdWithIndexSpec(String indexSpec) {
+ return getSd(joinLines("field t1 type tensor(x[64]) {",
+ " indexing: attribute | index",
+ " " + indexSpec,
+ "}"));
+ }
+
+ private void assertStartsWith(String prefix, String string) {
+ assertEquals(prefix, string.substring(0, Math.min(prefix.length(), string.length())));
+ }
+
+}