aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-04-14 17:21:26 +0200
committerGitHub <noreply@github.com>2023-04-14 17:21:26 +0200
commit0d7e4a262ed8f1dbbba19f44d0c312c3bc461123 (patch)
tree0fa80aa9a62d5cab86ad87b359925887bb17f3d2 /config-model/src/main/java/com/yahoo/schema
parent2087922d2af283d000729370932fd765c196f5a2 (diff)
parent9becb15cfa3381d362831077badbf5607673db98 (diff)
Merge pull request #26742 from vespa-engine/geirst/vsmfields-nearest-neighbor-search
Add vmsfields config for nearest neighbor search on supported tensor …
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java74
-rw-r--r--config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java9
2 files changed, 57 insertions, 26 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java
index c8679b6166c..c032a7155b2 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java
@@ -13,12 +13,14 @@ import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.schema.FieldSets;
import com.yahoo.schema.Schema;
+import com.yahoo.schema.document.Attribute;
import com.yahoo.schema.document.FieldSet;
import com.yahoo.schema.document.GeoPos;
import com.yahoo.schema.document.Matching;
import com.yahoo.schema.document.MatchType;
import com.yahoo.schema.document.SDDocumentType;
import com.yahoo.schema.document.SDField;
+import com.yahoo.schema.processing.TensorFieldProcessor;
import com.yahoo.vespa.config.search.vsm.VsmfieldsConfig;
import java.util.LinkedHashMap;
@@ -124,63 +126,68 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer {
private final Type type;
private final boolean isAttribute;
+ private final Attribute.DistanceMetric distanceMetric;
/** The streaming field type enumeration */
public static class Type {
- public static Type INT8 = new Type("int8","INT8");
- public static Type INT16 = new Type("int16","INT16");
- public static Type INT32 = new Type("int32","INT32");
- public static Type INT64 = new Type("int64","INT64");
- public static Type FLOAT16 = new Type("float16", "FLOAT16");
- public static Type FLOAT = new Type("float","FLOAT");
- public static Type DOUBLE = new Type("double","DOUBLE");
- public static Type STRING = new Type("string","AUTOUTF8");
- public static Type BOOL = new Type("bool","BOOL");
- public static Type UNSEARCHABLESTRING = new Type("string","NONE");
- public static Type GEO_POSITION = new Type("position", "GEOPOS");
-
- private String name;
+ public static Type INT8 = new Type("INT8");
+ public static Type INT16 = new Type("INT16");
+ public static Type INT32 = new Type("INT32");
+ public static Type INT64 = new Type("INT64");
+ public static Type FLOAT16 = new Type("FLOAT16");
+ public static Type FLOAT = new Type("FLOAT");
+ public static Type DOUBLE = new Type("DOUBLE");
+ public static Type STRING = new Type("AUTOUTF8");
+ public static Type BOOL = new Type("BOOL");
+ public static Type UNSEARCHABLESTRING = new Type("NONE");
+ public static Type GEO_POSITION = new Type("GEOPOS");
+ public static Type NEAREST_NEIGHBOR = new Type("NEAREST_NEIGHBOR");
private String searchMethod;
- private Type(String name, String searchMethod) {
- this.name = name;
+ private Type(String searchMethod) {
this.searchMethod = searchMethod;
}
@Override
public int hashCode() {
- return name.hashCode();
+ return searchMethod.hashCode();
}
- /** Returns the name of this type */
- public String getName() { return name; }
-
public String getSearchMethod() { return searchMethod; }
@Override
public boolean equals(Object other) {
if ( ! (other instanceof Type)) return false;
- return this.name.equals(((Type)other).name);
+ return this.searchMethod.equals(((Type)other).searchMethod);
}
@Override
public String toString() {
- return "type: " + name;
+ return "method: " + searchMethod;
}
}
public StreamingField(SDField field) {
- this(field.getName(), field.getDataType(), field.getMatching(), field.doesAttributing());
+ this(field.getName(), field.getDataType(), field.getMatching(), field.doesAttributing(), getDistanceMetric(field));
}
- private StreamingField(String name, DataType sourceType, Matching matching, boolean isAttribute) {
+ private StreamingField(String name, DataType sourceType, Matching matching, boolean isAttribute, Attribute.DistanceMetric distanceMetric) {
this.name = name;
this.type = convertType(sourceType);
this.matching = matching;
this.isAttribute = isAttribute;
+ this.distanceMetric = distanceMetric;
+ }
+
+ private static Attribute.DistanceMetric getDistanceMetric(SDField field) {
+ var attr = field.getAttribute();
+ if (attr != null) {
+ return attr.distanceMetric();
+ }
+ return Attribute.DEFAULT_DISTANCE_METRIC;
}
/** Converts to the right index type from a field datatype */
@@ -211,6 +218,10 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer {
} else if (fval instanceof PredicateFieldValue) {
return Type.UNSEARCHABLESTRING;
} else if (fval instanceof TensorFieldValue) {
+ var tensorType = ((TensorFieldValue) fval).getDataType().getTensorType();
+ if (TensorFieldProcessor.isTensorTypeThatSupportsHnswIndex(tensorType)) {
+ return Type.NEAREST_NEIGHBOR;
+ }
return Type.UNSEARCHABLESTRING;
} else if (fieldType instanceof CollectionDataType) {
return convertType(((CollectionDataType) fieldType).getNestedType());
@@ -224,8 +235,7 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer {
public String getName() { return name; }
- public VsmfieldsConfig.Fieldspec.Builder getFieldSpecConfig() {
- VsmfieldsConfig.Fieldspec.Builder fB = new VsmfieldsConfig.Fieldspec.Builder();
+ public String getMatchingName() {
String matchingName = matching.getType().getName();
if (matching.getType().equals(MatchType.TEXT))
matchingName = "";
@@ -241,9 +251,21 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer {
if (type != Type.STRING) {
matchingName = "";
}
+ return matchingName;
+ }
+
+ public String getArg1() {
+ if (type == Type.NEAREST_NEIGHBOR) {
+ return distanceMetric.name();
+ }
+ return getMatchingName();
+ }
+
+ public VsmfieldsConfig.Fieldspec.Builder getFieldSpecConfig() {
+ var fB = new VsmfieldsConfig.Fieldspec.Builder();
fB.name(getName())
.searchmethod(VsmfieldsConfig.Fieldspec.Searchmethod.Enum.valueOf(type.getSearchMethod()))
- .arg1(matchingName)
+ .arg1(getArg1())
.fieldtype(isAttribute
? VsmfieldsConfig.Fieldspec.Fieldtype.ATTRIBUTE
: VsmfieldsConfig.Fieldspec.Fieldtype.INDEX);
diff --git a/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java b/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java
index 37da07f8227..227054d9800 100644
--- a/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java
+++ b/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java
@@ -9,6 +9,7 @@ import com.yahoo.schema.Schema;
import com.yahoo.schema.document.HnswIndexParams;
import com.yahoo.schema.document.ImmutableSDField;
import com.yahoo.schema.document.SDField;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.container.search.QueryProfiles;
/**
@@ -50,6 +51,14 @@ public class TensorFieldProcessor extends Processor {
private boolean isTensorTypeThatSupportsHnswIndex(ImmutableSDField field) {
var type = ((TensorDataType)field.getDataType()).getTensorType();
+ return isTensorTypeThatSupportsHnswIndex(type);
+ }
+
+ /**
+ * Returns whether the given tensor type supports using HNSW index and
+ * nearest neighbor search.
+ */
+ public static boolean isTensorTypeThatSupportsHnswIndex(TensorType type) {
// Tensors with 1 indexed dimension support hnsw index (used for approximate nearest neighbor search).
if ((type.dimensions().size() == 1) &&
type.dimensions().get(0).isIndexed()) {