diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-04-14 17:21:26 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-14 17:21:26 +0200 |
commit | 0d7e4a262ed8f1dbbba19f44d0c312c3bc461123 (patch) | |
tree | 0fa80aa9a62d5cab86ad87b359925887bb17f3d2 /config-model/src/main/java/com/yahoo/schema | |
parent | 2087922d2af283d000729370932fd765c196f5a2 (diff) | |
parent | 9becb15cfa3381d362831077badbf5607673db98 (diff) |
Merge pull request #26742 from vespa-engine/geirst/vsmfields-nearest-neighbor-search
Add vsmfields config for nearest neighbor search on supported tensor …
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java | 74 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java | 9 |
2 files changed, 57 insertions, 26 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java index c8679b6166c..c032a7155b2 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java @@ -13,12 +13,14 @@ import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.schema.FieldSets; import com.yahoo.schema.Schema; +import com.yahoo.schema.document.Attribute; import com.yahoo.schema.document.FieldSet; import com.yahoo.schema.document.GeoPos; import com.yahoo.schema.document.Matching; import com.yahoo.schema.document.MatchType; import com.yahoo.schema.document.SDDocumentType; import com.yahoo.schema.document.SDField; +import com.yahoo.schema.processing.TensorFieldProcessor; import com.yahoo.vespa.config.search.vsm.VsmfieldsConfig; import java.util.LinkedHashMap; @@ -124,63 +126,68 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { private final Type type; private final boolean isAttribute; + private final Attribute.DistanceMetric distanceMetric; /** The streaming field type enumeration */ public static class Type { - public static Type INT8 = new Type("int8","INT8"); - public static Type INT16 = new Type("int16","INT16"); - public static Type INT32 = new Type("int32","INT32"); - public static Type INT64 = new Type("int64","INT64"); - public static Type FLOAT16 = new Type("float16", "FLOAT16"); - public static Type FLOAT = new Type("float","FLOAT"); - public static Type DOUBLE = new Type("double","DOUBLE"); - public static Type STRING = new Type("string","AUTOUTF8"); - public static Type BOOL = new Type("bool","BOOL"); - public static Type UNSEARCHABLESTRING = new Type("string","NONE"); - public static Type GEO_POSITION = new Type("position", "GEOPOS"); - - private String name; + public static Type INT8 = new Type("INT8"); + public 
static Type INT16 = new Type("INT16"); + public static Type INT32 = new Type("INT32"); + public static Type INT64 = new Type("INT64"); + public static Type FLOAT16 = new Type("FLOAT16"); + public static Type FLOAT = new Type("FLOAT"); + public static Type DOUBLE = new Type("DOUBLE"); + public static Type STRING = new Type("AUTOUTF8"); + public static Type BOOL = new Type("BOOL"); + public static Type UNSEARCHABLESTRING = new Type("NONE"); + public static Type GEO_POSITION = new Type("GEOPOS"); + public static Type NEAREST_NEIGHBOR = new Type("NEAREST_NEIGHBOR"); private String searchMethod; - private Type(String name, String searchMethod) { - this.name = name; + private Type(String searchMethod) { this.searchMethod = searchMethod; } @Override public int hashCode() { - return name.hashCode(); + return searchMethod.hashCode(); } - /** Returns the name of this type */ - public String getName() { return name; } - public String getSearchMethod() { return searchMethod; } @Override public boolean equals(Object other) { if ( ! 
(other instanceof Type)) return false; - return this.name.equals(((Type)other).name); + return this.searchMethod.equals(((Type)other).searchMethod); } @Override public String toString() { - return "type: " + name; + return "method: " + searchMethod; } } public StreamingField(SDField field) { - this(field.getName(), field.getDataType(), field.getMatching(), field.doesAttributing()); + this(field.getName(), field.getDataType(), field.getMatching(), field.doesAttributing(), getDistanceMetric(field)); } - private StreamingField(String name, DataType sourceType, Matching matching, boolean isAttribute) { + private StreamingField(String name, DataType sourceType, Matching matching, boolean isAttribute, Attribute.DistanceMetric distanceMetric) { this.name = name; this.type = convertType(sourceType); this.matching = matching; this.isAttribute = isAttribute; + this.distanceMetric = distanceMetric; + } + + private static Attribute.DistanceMetric getDistanceMetric(SDField field) { + var attr = field.getAttribute(); + if (attr != null) { + return attr.distanceMetric(); + } + return Attribute.DEFAULT_DISTANCE_METRIC; } /** Converts to the right index type from a field datatype */ @@ -211,6 +218,10 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { } else if (fval instanceof PredicateFieldValue) { return Type.UNSEARCHABLESTRING; } else if (fval instanceof TensorFieldValue) { + var tensorType = ((TensorFieldValue) fval).getDataType().getTensorType(); + if (TensorFieldProcessor.isTensorTypeThatSupportsHnswIndex(tensorType)) { + return Type.NEAREST_NEIGHBOR; + } return Type.UNSEARCHABLESTRING; } else if (fieldType instanceof CollectionDataType) { return convertType(((CollectionDataType) fieldType).getNestedType()); @@ -224,8 +235,7 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { public String getName() { return name; } - public VsmfieldsConfig.Fieldspec.Builder getFieldSpecConfig() { - VsmfieldsConfig.Fieldspec.Builder fB 
= new VsmfieldsConfig.Fieldspec.Builder(); + public String getMatchingName() { String matchingName = matching.getType().getName(); if (matching.getType().equals(MatchType.TEXT)) matchingName = ""; @@ -241,9 +251,21 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { if (type != Type.STRING) { matchingName = ""; } + return matchingName; + } + + public String getArg1() { + if (type == Type.NEAREST_NEIGHBOR) { + return distanceMetric.name(); + } + return getMatchingName(); + } + + public VsmfieldsConfig.Fieldspec.Builder getFieldSpecConfig() { + var fB = new VsmfieldsConfig.Fieldspec.Builder(); fB.name(getName()) .searchmethod(VsmfieldsConfig.Fieldspec.Searchmethod.Enum.valueOf(type.getSearchMethod())) - .arg1(matchingName) + .arg1(getArg1()) .fieldtype(isAttribute ? VsmfieldsConfig.Fieldspec.Fieldtype.ATTRIBUTE : VsmfieldsConfig.Fieldspec.Fieldtype.INDEX); diff --git a/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java b/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java index 37da07f8227..227054d9800 100644 --- a/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java +++ b/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java @@ -9,6 +9,7 @@ import com.yahoo.schema.Schema; import com.yahoo.schema.document.HnswIndexParams; import com.yahoo.schema.document.ImmutableSDField; import com.yahoo.schema.document.SDField; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.container.search.QueryProfiles; /** @@ -50,6 +51,14 @@ public class TensorFieldProcessor extends Processor { private boolean isTensorTypeThatSupportsHnswIndex(ImmutableSDField field) { var type = ((TensorDataType)field.getDataType()).getTensorType(); + return isTensorTypeThatSupportsHnswIndex(type); + } + + /** + * Returns whether the given tensor type supports using HNSW index and + * nearest neighbor search. 
+ */ + public static boolean isTensorTypeThatSupportsHnswIndex(TensorType type) { // Tensors with 1 indexed dimension support hnsw index (used for approximate nearest neighbor search). if ((type.dimensions().size() == 1) && type.dimensions().get(0).isIndexed()) { |