aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-04-14 13:16:24 +0000
committerGeir Storli <geirst@yahooinc.com>2023-04-14 13:16:24 +0000
commit9becb15cfa3381d362831077badbf5607673db98 (patch)
tree1915af6b5eb339910f8ca32c7ba1d20088c4c0f3 /config-model
parentb66f35888ad413800ac16f841582da5bf067cb7f (diff)
Add vsmfields config for nearest neighbor search on supported tensor fields.
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java74
-rw-r--r--config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java9
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/StreamingValidator.java3
-rw-r--r--config-model/src/test/derived/nearestneighbor_streaming/test.sd24
-rw-r--r--config-model/src/test/derived/nearestneighbor_streaming/vsmfields.cfg31
-rw-r--r--config-model/src/test/java/com/yahoo/schema/derived/NearestNeighborTestCase.java5
6 files changed, 120 insertions, 26 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java
index c8679b6166c..c032a7155b2 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java
@@ -13,12 +13,14 @@ import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.schema.FieldSets;
import com.yahoo.schema.Schema;
+import com.yahoo.schema.document.Attribute;
import com.yahoo.schema.document.FieldSet;
import com.yahoo.schema.document.GeoPos;
import com.yahoo.schema.document.Matching;
import com.yahoo.schema.document.MatchType;
import com.yahoo.schema.document.SDDocumentType;
import com.yahoo.schema.document.SDField;
+import com.yahoo.schema.processing.TensorFieldProcessor;
import com.yahoo.vespa.config.search.vsm.VsmfieldsConfig;
import java.util.LinkedHashMap;
@@ -124,63 +126,68 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer {
private final Type type;
private final boolean isAttribute;
+ private final Attribute.DistanceMetric distanceMetric;
/** The streaming field type enumeration */
public static class Type {
- public static Type INT8 = new Type("int8","INT8");
- public static Type INT16 = new Type("int16","INT16");
- public static Type INT32 = new Type("int32","INT32");
- public static Type INT64 = new Type("int64","INT64");
- public static Type FLOAT16 = new Type("float16", "FLOAT16");
- public static Type FLOAT = new Type("float","FLOAT");
- public static Type DOUBLE = new Type("double","DOUBLE");
- public static Type STRING = new Type("string","AUTOUTF8");
- public static Type BOOL = new Type("bool","BOOL");
- public static Type UNSEARCHABLESTRING = new Type("string","NONE");
- public static Type GEO_POSITION = new Type("position", "GEOPOS");
-
- private String name;
+ public static Type INT8 = new Type("INT8");
+ public static Type INT16 = new Type("INT16");
+ public static Type INT32 = new Type("INT32");
+ public static Type INT64 = new Type("INT64");
+ public static Type FLOAT16 = new Type("FLOAT16");
+ public static Type FLOAT = new Type("FLOAT");
+ public static Type DOUBLE = new Type("DOUBLE");
+ public static Type STRING = new Type("AUTOUTF8");
+ public static Type BOOL = new Type("BOOL");
+ public static Type UNSEARCHABLESTRING = new Type("NONE");
+ public static Type GEO_POSITION = new Type("GEOPOS");
+ public static Type NEAREST_NEIGHBOR = new Type("NEAREST_NEIGHBOR");
private String searchMethod;
- private Type(String name, String searchMethod) {
- this.name = name;
+ private Type(String searchMethod) {
this.searchMethod = searchMethod;
}
@Override
public int hashCode() {
- return name.hashCode();
+ return searchMethod.hashCode();
}
- /** Returns the name of this type */
- public String getName() { return name; }
-
public String getSearchMethod() { return searchMethod; }
@Override
public boolean equals(Object other) {
if ( ! (other instanceof Type)) return false;
- return this.name.equals(((Type)other).name);
+ return this.searchMethod.equals(((Type)other).searchMethod);
}
@Override
public String toString() {
- return "type: " + name;
+ return "method: " + searchMethod;
}
}
public StreamingField(SDField field) {
- this(field.getName(), field.getDataType(), field.getMatching(), field.doesAttributing());
+ this(field.getName(), field.getDataType(), field.getMatching(), field.doesAttributing(), getDistanceMetric(field));
}
- private StreamingField(String name, DataType sourceType, Matching matching, boolean isAttribute) {
+ private StreamingField(String name, DataType sourceType, Matching matching, boolean isAttribute, Attribute.DistanceMetric distanceMetric) {
this.name = name;
this.type = convertType(sourceType);
this.matching = matching;
this.isAttribute = isAttribute;
+ this.distanceMetric = distanceMetric;
+ }
+
+ private static Attribute.DistanceMetric getDistanceMetric(SDField field) {
+ var attr = field.getAttribute();
+ if (attr != null) {
+ return attr.distanceMetric();
+ }
+ return Attribute.DEFAULT_DISTANCE_METRIC;
}
/** Converts to the right index type from a field datatype */
@@ -211,6 +218,10 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer {
} else if (fval instanceof PredicateFieldValue) {
return Type.UNSEARCHABLESTRING;
} else if (fval instanceof TensorFieldValue) {
+ var tensorType = ((TensorFieldValue) fval).getDataType().getTensorType();
+ if (TensorFieldProcessor.isTensorTypeThatSupportsHnswIndex(tensorType)) {
+ return Type.NEAREST_NEIGHBOR;
+ }
return Type.UNSEARCHABLESTRING;
} else if (fieldType instanceof CollectionDataType) {
return convertType(((CollectionDataType) fieldType).getNestedType());
@@ -224,8 +235,7 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer {
public String getName() { return name; }
- public VsmfieldsConfig.Fieldspec.Builder getFieldSpecConfig() {
- VsmfieldsConfig.Fieldspec.Builder fB = new VsmfieldsConfig.Fieldspec.Builder();
+ public String getMatchingName() {
String matchingName = matching.getType().getName();
if (matching.getType().equals(MatchType.TEXT))
matchingName = "";
@@ -241,9 +251,21 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer {
if (type != Type.STRING) {
matchingName = "";
}
+ return matchingName;
+ }
+
+ public String getArg1() {
+ if (type == Type.NEAREST_NEIGHBOR) {
+ return distanceMetric.name();
+ }
+ return getMatchingName();
+ }
+
+ public VsmfieldsConfig.Fieldspec.Builder getFieldSpecConfig() {
+ var fB = new VsmfieldsConfig.Fieldspec.Builder();
fB.name(getName())
.searchmethod(VsmfieldsConfig.Fieldspec.Searchmethod.Enum.valueOf(type.getSearchMethod()))
- .arg1(matchingName)
+ .arg1(getArg1())
.fieldtype(isAttribute
? VsmfieldsConfig.Fieldspec.Fieldtype.ATTRIBUTE
: VsmfieldsConfig.Fieldspec.Fieldtype.INDEX);
diff --git a/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java b/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java
index 37da07f8227..227054d9800 100644
--- a/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java
+++ b/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java
@@ -9,6 +9,7 @@ import com.yahoo.schema.Schema;
import com.yahoo.schema.document.HnswIndexParams;
import com.yahoo.schema.document.ImmutableSDField;
import com.yahoo.schema.document.SDField;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.container.search.QueryProfiles;
/**
@@ -50,6 +51,14 @@ public class TensorFieldProcessor extends Processor {
private boolean isTensorTypeThatSupportsHnswIndex(ImmutableSDField field) {
var type = ((TensorDataType)field.getDataType()).getTensorType();
+ return isTensorTypeThatSupportsHnswIndex(type);
+ }
+
+ /**
+ * Returns whether the given tensor type supports using HNSW index and
+ * nearest neighbor search.
+ */
+ public static boolean isTensorTypeThatSupportsHnswIndex(TensorType type) {
// Tensors with 1 indexed dimension support hnsw index (used for approximate nearest neighbor search).
if ((type.dimensions().size() == 1) &&
type.dimensions().get(0).isIndexed()) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/StreamingValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/StreamingValidator.java
index 773d696f3e8..ad126cfa22b 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/StreamingValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/StreamingValidator.java
@@ -5,6 +5,7 @@ import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.document.DataType;
import com.yahoo.document.NumericDataType;
+import com.yahoo.document.TensorDataType;
import com.yahoo.documentmodel.NewDocumentReferenceDataType;
import com.yahoo.schema.document.Attribute;
import com.yahoo.schema.document.ImmutableSDField;
@@ -63,6 +64,8 @@ public class StreamingValidator extends Validator {
// If the field is numeric, we can't print this, because we may have converted the field to
// attribute indexing ourselves (IntegerIndex2Attribute)
if (sd.getDataType() instanceof NumericDataType) return;
+ // Tensor fields are only searchable via nearest neighbor search, and match semantics are irrelevant.
+ if (sd.getDataType() instanceof TensorDataType) return;
logger.logApplicationPackage(Level.WARNING, "For streaming search cluster '" + sc.getClusterName() +
"', SD field '" + sd.getName() +
"': 'attribute' has same match semantics as 'index'.");
diff --git a/config-model/src/test/derived/nearestneighbor_streaming/test.sd b/config-model/src/test/derived/nearestneighbor_streaming/test.sd
new file mode 100644
index 00000000000..4427fa08ab6
--- /dev/null
+++ b/config-model/src/test/derived/nearestneighbor_streaming/test.sd
@@ -0,0 +1,24 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+schema test {
+ document test {
+ field vec_a type tensor<float>(x[16]) {
+ indexing: attribute
+ }
+ field vec_b type tensor<float>(x[16]) {
+ indexing: attribute
+ attribute {
+ distance-metric: angular
+ }
+ }
+ field vec_c type tensor<float>(m{},x[16]) {
+ indexing: attribute
+ attribute {
+ distance-metric: innerproduct
+ }
+ }
+ # This tensor field can not be used with nearest neighbor search.
+ field vec_d type tensor<float>(x{}) {
+ indexing: attribute
+ }
+ }
+}
diff --git a/config-model/src/test/derived/nearestneighbor_streaming/vsmfields.cfg b/config-model/src/test/derived/nearestneighbor_streaming/vsmfields.cfg
new file mode 100644
index 00000000000..f8b1cf62048
--- /dev/null
+++ b/config-model/src/test/derived/nearestneighbor_streaming/vsmfields.cfg
@@ -0,0 +1,31 @@
+documentverificationlevel 0
+searchall 1
+fieldspec[].name "vec_a"
+fieldspec[].searchmethod NEAREST_NEIGHBOR
+fieldspec[].arg1 "EUCLIDEAN"
+fieldspec[].maxlength 1048576
+fieldspec[].fieldtype ATTRIBUTE
+fieldspec[].name "vec_b"
+fieldspec[].searchmethod NEAREST_NEIGHBOR
+fieldspec[].arg1 "ANGULAR"
+fieldspec[].maxlength 1048576
+fieldspec[].fieldtype ATTRIBUTE
+fieldspec[].name "vec_c"
+fieldspec[].searchmethod NEAREST_NEIGHBOR
+fieldspec[].arg1 "INNERPRODUCT"
+fieldspec[].maxlength 1048576
+fieldspec[].fieldtype ATTRIBUTE
+fieldspec[].name "vec_d"
+fieldspec[].searchmethod NONE
+fieldspec[].arg1 ""
+fieldspec[].maxlength 1048576
+fieldspec[].fieldtype ATTRIBUTE
+documenttype[].name "test"
+documenttype[].index[].name "vec_a"
+documenttype[].index[].field[].name "vec_a"
+documenttype[].index[].name "vec_b"
+documenttype[].index[].field[].name "vec_b"
+documenttype[].index[].name "vec_c"
+documenttype[].index[].field[].name "vec_c"
+documenttype[].index[].name "vec_d"
+documenttype[].index[].field[].name "vec_d"
diff --git a/config-model/src/test/java/com/yahoo/schema/derived/NearestNeighborTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/NearestNeighborTestCase.java
index b3a0b8d4558..713da6f5cbe 100644
--- a/config-model/src/test/java/com/yahoo/schema/derived/NearestNeighborTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/derived/NearestNeighborTestCase.java
@@ -35,4 +35,9 @@ public class NearestNeighborTestCase extends AbstractExportingTestCase {
}
}
+ @Test
+ void test_nearest_neighbor_streaming() throws IOException, ParseException {
+ assertCorrectDeriving("nearestneighbor_streaming");
+ }
+
}