From 9becb15cfa3381d362831077badbf5607673db98 Mon Sep 17 00:00:00 2001 From: Geir Storli Date: Fri, 14 Apr 2023 13:16:24 +0000 Subject: Add vmsfields config for nearest neighbor search on supported tensor fields. --- .../java/com/yahoo/schema/derived/VsmFields.java | 74 ++++++++++++++-------- .../schema/processing/TensorFieldProcessor.java | 9 +++ .../application/validation/StreamingValidator.java | 3 + .../test/derived/nearestneighbor_streaming/test.sd | 24 +++++++ .../nearestneighbor_streaming/vsmfields.cfg | 31 +++++++++ .../schema/derived/NearestNeighborTestCase.java | 5 ++ 6 files changed, 120 insertions(+), 26 deletions(-) create mode 100644 config-model/src/test/derived/nearestneighbor_streaming/test.sd create mode 100644 config-model/src/test/derived/nearestneighbor_streaming/vsmfields.cfg (limited to 'config-model/src') diff --git a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java index c8679b6166c..c032a7155b2 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java @@ -13,12 +13,14 @@ import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.schema.FieldSets; import com.yahoo.schema.Schema; +import com.yahoo.schema.document.Attribute; import com.yahoo.schema.document.FieldSet; import com.yahoo.schema.document.GeoPos; import com.yahoo.schema.document.Matching; import com.yahoo.schema.document.MatchType; import com.yahoo.schema.document.SDDocumentType; import com.yahoo.schema.document.SDField; +import com.yahoo.schema.processing.TensorFieldProcessor; import com.yahoo.vespa.config.search.vsm.VsmfieldsConfig; import java.util.LinkedHashMap; @@ -124,63 +126,68 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { private final Type type; private final boolean isAttribute; + private final Attribute.DistanceMetric distanceMetric; /** The streaming field type enumeration */ public static class Type { - public static Type INT8 = new Type("int8","INT8"); - public static Type INT16 = new Type("int16","INT16"); - public static Type INT32 = new Type("int32","INT32"); - public static Type INT64 = new Type("int64","INT64"); - public static Type FLOAT16 = new Type("float16", "FLOAT16"); - public static Type FLOAT = new Type("float","FLOAT"); - public static Type DOUBLE = new Type("double","DOUBLE"); - public static Type STRING = new Type("string","AUTOUTF8"); - public static Type BOOL = new Type("bool","BOOL"); - public static Type UNSEARCHABLESTRING = new Type("string","NONE"); - public static Type GEO_POSITION = new Type("position", "GEOPOS"); - - private String name; + public static Type INT8 = new Type("INT8"); + public static Type INT16 = new Type("INT16"); + public static Type INT32 = new Type("INT32"); + public static Type INT64 = new Type("INT64"); + public static Type FLOAT16 = new Type("FLOAT16"); + public static Type FLOAT = new Type("FLOAT"); + public static Type DOUBLE = new Type("DOUBLE"); + public static Type STRING = new Type("AUTOUTF8"); + public static Type BOOL = new Type("BOOL"); + public static Type UNSEARCHABLESTRING = new Type("NONE"); + public static Type GEO_POSITION = new Type("GEOPOS"); + public static Type NEAREST_NEIGHBOR = new Type("NEAREST_NEIGHBOR"); private String searchMethod; - private Type(String name, String searchMethod) { - this.name = name; + private Type(String searchMethod) { this.searchMethod = searchMethod; } @Override public int hashCode() { - return name.hashCode(); + return searchMethod.hashCode(); } - /** Returns the name of this type */ - public String getName() { return name; } - public String getSearchMethod() { return searchMethod; } @Override public boolean equals(Object other) { if ( ! (other instanceof Type)) return false; - return this.name.equals(((Type)other).name); + return this.searchMethod.equals(((Type)other).searchMethod); } @Override public String toString() { - return "type: " + name; + return "method: " + searchMethod; } } public StreamingField(SDField field) { - this(field.getName(), field.getDataType(), field.getMatching(), field.doesAttributing()); + this(field.getName(), field.getDataType(), field.getMatching(), field.doesAttributing(), getDistanceMetric(field)); } - private StreamingField(String name, DataType sourceType, Matching matching, boolean isAttribute) { + private StreamingField(String name, DataType sourceType, Matching matching, boolean isAttribute, Attribute.DistanceMetric distanceMetric) { this.name = name; this.type = convertType(sourceType); this.matching = matching; this.isAttribute = isAttribute; + this.distanceMetric = distanceMetric; + } + + private static Attribute.DistanceMetric getDistanceMetric(SDField field) { + var attr = field.getAttribute(); + if (attr != null) { + return attr.distanceMetric(); + } + return Attribute.DEFAULT_DISTANCE_METRIC; } /** Converts to the right index type from a field datatype */ @@ -211,6 +218,10 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { } else if (fval instanceof PredicateFieldValue) { return Type.UNSEARCHABLESTRING; } else if (fval instanceof TensorFieldValue) { + var tensorType = ((TensorFieldValue) fval).getDataType().getTensorType(); + if (TensorFieldProcessor.isTensorTypeThatSupportsHnswIndex(tensorType)) { + return Type.NEAREST_NEIGHBOR; + } return Type.UNSEARCHABLESTRING; } else if (fieldType instanceof CollectionDataType) { return convertType(((CollectionDataType) fieldType).getNestedType()); @@ -224,8 +235,7 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { public String getName() { return name; } - public VsmfieldsConfig.Fieldspec.Builder getFieldSpecConfig() { - VsmfieldsConfig.Fieldspec.Builder fB = new VsmfieldsConfig.Fieldspec.Builder(); + public String getMatchingName() { String matchingName = matching.getType().getName(); if (matching.getType().equals(MatchType.TEXT)) matchingName = ""; @@ -241,9 +251,21 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { if (type != Type.STRING) { matchingName = ""; } + return matchingName; + } + + public String getArg1() { + if (type == Type.NEAREST_NEIGHBOR) { + return distanceMetric.name(); + } + return getMatchingName(); + } + + public VsmfieldsConfig.Fieldspec.Builder getFieldSpecConfig() { + var fB = new VsmfieldsConfig.Fieldspec.Builder(); fB.name(getName()) .searchmethod(VsmfieldsConfig.Fieldspec.Searchmethod.Enum.valueOf(type.getSearchMethod())) - .arg1(matchingName) + .arg1(getArg1()) .fieldtype(isAttribute ? VsmfieldsConfig.Fieldspec.Fieldtype.ATTRIBUTE : VsmfieldsConfig.Fieldspec.Fieldtype.INDEX); diff --git a/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java b/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java index 37da07f8227..227054d9800 100644 --- a/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java +++ b/config-model/src/main/java/com/yahoo/schema/processing/TensorFieldProcessor.java @@ -9,6 +9,7 @@ import com.yahoo.schema.Schema; import com.yahoo.schema.document.HnswIndexParams; import com.yahoo.schema.document.ImmutableSDField; import com.yahoo.schema.document.SDField; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.container.search.QueryProfiles; /** @@ -50,6 +51,14 @@ public class TensorFieldProcessor extends Processor { private boolean isTensorTypeThatSupportsHnswIndex(ImmutableSDField field) { var type = ((TensorDataType)field.getDataType()).getTensorType(); + return isTensorTypeThatSupportsHnswIndex(type); + } + + /** + * Returns whether the given tensor type supports using HNSW index and + * nearest neighbor search. + */ + public static boolean isTensorTypeThatSupportsHnswIndex(TensorType type) { // Tensors with 1 indexed dimension support hnsw index (used for approximate nearest neighbor search). if ((type.dimensions().size() == 1) && type.dimensions().get(0).isIndexed()) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/StreamingValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/StreamingValidator.java index 773d696f3e8..ad126cfa22b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/StreamingValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/StreamingValidator.java @@ -5,6 +5,7 @@ import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.document.DataType; import com.yahoo.document.NumericDataType; +import com.yahoo.document.TensorDataType; import com.yahoo.documentmodel.NewDocumentReferenceDataType; import com.yahoo.schema.document.Attribute; import com.yahoo.schema.document.ImmutableSDField; @@ -63,6 +64,8 @@ public class StreamingValidator extends Validator { // If the field is numeric, we can't print this, because we may have converted the field to // attribute indexing ourselves (IntegerIndex2Attribute) if (sd.getDataType() instanceof NumericDataType) return; + // Tensor fields are only searchable via nearest neighbor search, and match semantics are irrelevant. + if (sd.getDataType() instanceof TensorDataType) return; logger.logApplicationPackage(Level.WARNING, "For streaming search cluster '" + sc.getClusterName() + "', SD field '" + sd.getName() + "': 'attribute' has same match semantics as 'index'."); diff --git a/config-model/src/test/derived/nearestneighbor_streaming/test.sd b/config-model/src/test/derived/nearestneighbor_streaming/test.sd new file mode 100644 index 00000000000..4427fa08ab6 --- /dev/null +++ b/config-model/src/test/derived/nearestneighbor_streaming/test.sd @@ -0,0 +1,24 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +schema test { + document test { + field vec_a type tensor(x[16]) { + indexing: attribute + } + field vec_b type tensor(x[16]) { + indexing: attribute + attribute { + distance-metric: angular + } + } + field vec_c type tensor(m{},x[16]) { + indexing: attribute + attribute { + distance-metric: innerproduct + } + } + # This tensor field can not be used with nearest neighbor search. + field vec_d type tensor(x{}) { + indexing: attribute + } + } +} diff --git a/config-model/src/test/derived/nearestneighbor_streaming/vsmfields.cfg b/config-model/src/test/derived/nearestneighbor_streaming/vsmfields.cfg new file mode 100644 index 00000000000..f8b1cf62048 --- /dev/null +++ b/config-model/src/test/derived/nearestneighbor_streaming/vsmfields.cfg @@ -0,0 +1,31 @@ +documentverificationlevel 0 +searchall 1 +fieldspec[].name "vec_a" +fieldspec[].searchmethod NEAREST_NEIGHBOR +fieldspec[].arg1 "EUCLIDEAN" +fieldspec[].maxlength 1048576 +fieldspec[].fieldtype ATTRIBUTE +fieldspec[].name "vec_b" +fieldspec[].searchmethod NEAREST_NEIGHBOR +fieldspec[].arg1 "ANGULAR" +fieldspec[].maxlength 1048576 +fieldspec[].fieldtype ATTRIBUTE +fieldspec[].name "vec_c" +fieldspec[].searchmethod NEAREST_NEIGHBOR +fieldspec[].arg1 "INNERPRODUCT" +fieldspec[].maxlength 1048576 +fieldspec[].fieldtype ATTRIBUTE +fieldspec[].name "vec_d" +fieldspec[].searchmethod NONE +fieldspec[].arg1 "" +fieldspec[].maxlength 1048576 +fieldspec[].fieldtype ATTRIBUTE +documenttype[].name "test" +documenttype[].index[].name "vec_a" +documenttype[].index[].field[].name "vec_a" +documenttype[].index[].name "vec_b" +documenttype[].index[].field[].name "vec_b" +documenttype[].index[].name "vec_c" +documenttype[].index[].field[].name "vec_c" +documenttype[].index[].name "vec_d" +documenttype[].index[].field[].name "vec_d" diff --git a/config-model/src/test/java/com/yahoo/schema/derived/NearestNeighborTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/NearestNeighborTestCase.java index b3a0b8d4558..713da6f5cbe 100644 --- a/config-model/src/test/java/com/yahoo/schema/derived/NearestNeighborTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/derived/NearestNeighborTestCase.java @@ -35,4 +35,9 @@ public class NearestNeighborTestCase extends AbstractExportingTestCase { } } + @Test + void test_nearest_neighbor_streaming() throws IOException, ParseException { + assertCorrectDeriving("nearestneighbor_streaming"); + } + } -- cgit v1.2.3